karyon_p2p/discovery/refresh.rs

1use std::{sync::Arc, time::Duration};
2
3use bincode::{Decode, Encode};
4use log::{error, info, trace};
5use parking_lot::RwLock;
6use rand::{rngs::OsRng, RngCore};
7
8use karyon_core::{
9    async_runtime::Executor,
10    async_util::{sleep, timeout, Backoff, TaskGroup, TaskResult},
11};
12
13use karyon_net::{udp, Connection, Endpoint};
14
15use crate::{
16    codec::RefreshMsgCodec,
17    message::RefreshMsg,
18    monitor::{ConnEvent, DiscvEvent, Monitor},
19    routing_table::{BucketEntry, Entry, RoutingTable, PENDING_ENTRY, UNREACHABLE_ENTRY},
20    Config, Error, Result,
21};
22
/// Failure count at which a routing-table entry is evicted instead of being
/// pinged again during a refresh round.
pub const MAX_FAILURES: u32 = 3;
25
26#[derive(Decode, Encode, Debug, Clone)]
27pub struct PingMsg(pub [u8; 32]);
28
29#[derive(Decode, Encode, Debug)]
30pub struct PongMsg(pub [u8; 32]);
31
32pub struct RefreshService {
33    /// Routing table
34    table: Arc<RoutingTable>,
35
36    /// Resolved listen endpoint
37    listen_endpoint: RwLock<Option<Endpoint>>,
38
39    /// Managing spawned tasks.
40    task_group: TaskGroup,
41
42    /// Holds the configuration for the P2P network.
43    config: Arc<Config>,
44
45    /// Responsible for network and system monitoring.
46    monitor: Arc<Monitor>,
47}
48
49impl RefreshService {
50    /// Creates a new refresh service
51    pub fn new(
52        config: Arc<Config>,
53        table: Arc<RoutingTable>,
54        monitor: Arc<Monitor>,
55        executor: Executor,
56    ) -> Self {
57        Self {
58            table,
59            listen_endpoint: RwLock::new(None),
60            task_group: TaskGroup::with_executor(executor.clone()),
61            config,
62            monitor,
63        }
64    }
65
66    /// Start the refresh service
67    pub async fn start(self: &Arc<Self>) -> Result<()> {
68        if let Some(endpoint) = self.listen_endpoint.read().as_ref() {
69            let endpoint = endpoint.clone();
70            self.task_group.spawn(
71                {
72                    let this = self.clone();
73                    async move { this.listen_loop(endpoint).await }
74                },
75                |res| async move {
76                    if let TaskResult::Completed(Err(err)) = res {
77                        error!("Listen loop stopped: {err}");
78                    }
79                },
80            );
81        }
82
83        self.task_group.spawn(
84            {
85                let this = self.clone();
86                async move { this.refresh_loop().await }
87            },
88            |res| async move {
89                if let TaskResult::Completed(Err(err)) = res {
90                    error!("Refresh loop stopped: {err}");
91                }
92            },
93        );
94
95        Ok(())
96    }
97
98    /// Set the resolved listen endpoint.
99    pub fn set_listen_endpoint(&self, resolved_endpoint: &Endpoint) -> Result<()> {
100        let resolved_endpoint = Endpoint::Udp(
101            resolved_endpoint.addr()?.clone(),
102            self.config.discovery_port,
103        );
104        *self.listen_endpoint.write() = Some(resolved_endpoint);
105        Ok(())
106    }
107
108    /// Shuts down the refresh service
109    pub async fn shutdown(&self) {
110        self.task_group.cancel().await;
111    }
112
113    /// Initiates periodic refreshing of the routing table. This function will
114    /// selects the first 8 entries (oldest entries) from each bucket in the
115    /// routing table and starts sending Ping messages to the collected entries.
116    async fn refresh_loop(self: Arc<Self>) -> Result<()> {
117        loop {
118            sleep(Duration::from_secs(self.config.refresh_interval)).await;
119            trace!("Start refreshing the routing table...");
120
121            self.monitor.notify(DiscvEvent::RefreshStarted).await;
122
123            let mut entries: Vec<BucketEntry> = vec![];
124            for bucket in self.table.buckets() {
125                for entry in bucket
126                    .iter()
127                    .filter(|e| !e.is_connected() && !e.is_incompatible())
128                    .take(8)
129                {
130                    entries.push(entry.clone())
131                }
132            }
133
134            self.clone().do_refresh(&entries).await;
135        }
136    }
137
138    /// Iterates over the entries and initiates a connection.
139    async fn do_refresh(self: Arc<Self>, entries: &[BucketEntry]) {
140        use futures_util::stream::{FuturesUnordered, StreamExt};
141        // Enforce a maximum of 16 connections.
142        for chunk in entries.chunks(16) {
143            let mut tasks = FuturesUnordered::new();
144            for bucket_entry in chunk {
145                if bucket_entry.failures >= MAX_FAILURES {
146                    self.table.remove_entry(&bucket_entry.entry.key);
147                    continue;
148                }
149
150                tasks.push(self.clone().refresh_entry(bucket_entry.clone()))
151            }
152
153            while tasks.next().await.is_some() {}
154        }
155    }
156
157    /// Initiates refresh for a specific entry within the routing table. It
158    /// updates the routing table according to the result.
159    async fn refresh_entry(self: Arc<Self>, bucket_entry: BucketEntry) {
160        let key = &bucket_entry.entry.key;
161        match self.connect(&bucket_entry.entry).await {
162            Ok(_) => {
163                self.table.update_entry(key, PENDING_ENTRY);
164            }
165            Err(err) => {
166                trace!("Failed to refresh entry {:?}: {err}", key);
167                if bucket_entry.failures >= MAX_FAILURES {
168                    self.table.remove_entry(key);
169                    return;
170                }
171                self.table.update_entry(key, UNREACHABLE_ENTRY);
172            }
173        }
174    }
175
176    /// Initiates a UDP connection with the entry and attempts to send a Ping
177    /// message. If it fails, it retries according to the allowed retries
178    /// specified in the Config, with backoff between each retry.
179    async fn connect(&self, entry: &Entry) -> Result<()> {
180        let mut retry = 0;
181        let endpoint = Endpoint::Udp(entry.addr.clone(), entry.discovery_port);
182        let conn = udp::dial(&endpoint, Default::default(), RefreshMsgCodec {}).await?;
183        let backoff = Backoff::new(100, 5000);
184        while retry < self.config.refresh_connect_retries {
185            match self.send_ping_msg(&conn, &endpoint).await {
186                Ok(()) => return Ok(()),
187                Err(Error::Timeout) => {
188                    retry += 1;
189                    backoff.sleep().await;
190                }
191                Err(err) => {
192                    return Err(err);
193                }
194            }
195        }
196
197        Err(Error::Timeout)
198    }
199
200    /// Set up a UDP listener and start listening for Ping messages from other
201    /// peers.
202    async fn listen_loop(self: Arc<Self>, endpoint: Endpoint) -> Result<()> {
203        let conn = match udp::listen(&endpoint, Default::default(), RefreshMsgCodec {}).await {
204            Ok(c) => {
205                self.monitor
206                    .notify(ConnEvent::Listening(endpoint.clone()))
207                    .await;
208                c
209            }
210            Err(err) => {
211                self.monitor
212                    .notify(ConnEvent::ListenFailed(endpoint.clone()))
213                    .await;
214                return Err(err.into());
215            }
216        };
217        info!("Start listening on {endpoint}");
218
219        loop {
220            let res = self.listen_to_ping_msg(&conn).await;
221            if let Err(err) = res {
222                trace!("Failed to handle ping msg {err}");
223                self.monitor.notify(ConnEvent::AcceptFailed).await;
224            }
225        }
226    }
227
228    /// Listen to receive a Ping message and respond with a Pong message.
229    async fn listen_to_ping_msg(&self, conn: &udp::UdpConn<RefreshMsgCodec>) -> Result<()> {
230        let (msg, endpoint) = conn.recv().await?;
231        self.monitor
232            .notify(ConnEvent::Accepted(endpoint.clone()))
233            .await;
234
235        match msg {
236            RefreshMsg::Ping(m) => {
237                let pong_msg = RefreshMsg::Pong(m);
238                conn.send((pong_msg, endpoint.clone())).await?;
239            }
240            RefreshMsg::Pong(_) => return Err(Error::InvalidMsg("Unexpected pong msg".into())),
241        }
242
243        self.monitor.notify(ConnEvent::Disconnected(endpoint)).await;
244        Ok(())
245    }
246
247    /// Sends a Ping msg and wait to receive the Pong message.
248    async fn send_ping_msg(
249        &self,
250        conn: &udp::UdpConn<RefreshMsgCodec>,
251        endpoint: &Endpoint,
252    ) -> Result<()> {
253        let mut nonce: [u8; 32] = [0; 32];
254        RngCore::fill_bytes(&mut OsRng, &mut nonce);
255        conn.send((RefreshMsg::Ping(nonce), endpoint.clone()))
256            .await?;
257
258        let t = Duration::from_secs(self.config.refresh_response_timeout);
259        let (msg, _) = timeout(t, conn.recv()).await??;
260
261        match msg {
262            RefreshMsg::Pong(n) => {
263                if n != nonce {
264                    return Err(Error::InvalidPongMsg);
265                }
266                Ok(())
267            }
268            _ => Err(Error::InvalidMsg("Unexpected ping msg".into())),
269        }
270    }
271}