karyon_p2p/discovery/
refresh.rs1use std::{sync::Arc, time::Duration};
2
3use bincode::{Decode, Encode};
4use log::{error, info, trace};
5use parking_lot::RwLock;
6use rand::{rngs::OsRng, RngCore};
7
8use karyon_core::{
9 async_runtime::Executor,
10 async_util::{sleep, timeout, Backoff, TaskGroup, TaskResult},
11};
12
13use karyon_net::{udp, Connection, Endpoint};
14
15use crate::{
16 codec::RefreshMsgCodec,
17 message::RefreshMsg,
18 monitor::{ConnEvent, DiscvEvent, Monitor},
19 routing_table::{BucketEntry, Entry, RoutingTable, PENDING_ENTRY, UNREACHABLE_ENTRY},
20 Config, Error, Result,
21};
22
/// Failure count at which a routing-table entry is removed during a refresh
/// round instead of being pinged again.
pub const MAX_FAILURES: u32 = 3;
25
/// Ping message wrapping a 32-byte payload.
///
/// NOTE(review): presumably the payload is the ping nonce; the refresh logic
/// in this file sends `RefreshMsg::Ping` with a raw `[u8; 32]` and does not
/// use this type directly — confirm where it is still needed.
#[derive(Decode, Encode, Debug, Clone)]
pub struct PingMsg(pub [u8; 32]);
28
/// Pong message wrapping a 32-byte payload.
///
/// NOTE(review): presumably the payload echoes the nonce of the ping being
/// answered; the refresh logic in this file replies with `RefreshMsg::Pong`
/// and does not use this type directly — confirm where it is still needed.
#[derive(Decode, Encode, Debug)]
pub struct PongMsg(pub [u8; 32]);
31
/// Service that periodically pings routing-table entries over UDP and
/// answers ping messages sent by other peers.
pub struct RefreshService {
    /// Routing table whose entries are refreshed, updated, or removed.
    table: Arc<RoutingTable>,

    /// UDP endpoint to listen on for incoming pings. `None` until
    /// `set_listen_endpoint` is called; `start` only spawns the listen loop
    /// when it is set.
    listen_endpoint: RwLock<Option<Endpoint>>,

    /// Task group running the listen and refresh loops; cancelled by
    /// `shutdown`.
    task_group: TaskGroup,

    /// Shared configuration (discovery port, refresh interval, retry count,
    /// and response timeout are read from it).
    config: Arc<Config>,

    /// Monitor notified about discovery and connection events.
    monitor: Arc<Monitor>,
}
48
49impl RefreshService {
50 pub fn new(
52 config: Arc<Config>,
53 table: Arc<RoutingTable>,
54 monitor: Arc<Monitor>,
55 executor: Executor,
56 ) -> Self {
57 Self {
58 table,
59 listen_endpoint: RwLock::new(None),
60 task_group: TaskGroup::with_executor(executor.clone()),
61 config,
62 monitor,
63 }
64 }
65
66 pub async fn start(self: &Arc<Self>) -> Result<()> {
68 if let Some(endpoint) = self.listen_endpoint.read().as_ref() {
69 let endpoint = endpoint.clone();
70 self.task_group.spawn(
71 {
72 let this = self.clone();
73 async move { this.listen_loop(endpoint).await }
74 },
75 |res| async move {
76 if let TaskResult::Completed(Err(err)) = res {
77 error!("Listen loop stopped: {err}");
78 }
79 },
80 );
81 }
82
83 self.task_group.spawn(
84 {
85 let this = self.clone();
86 async move { this.refresh_loop().await }
87 },
88 |res| async move {
89 if let TaskResult::Completed(Err(err)) = res {
90 error!("Refresh loop stopped: {err}");
91 }
92 },
93 );
94
95 Ok(())
96 }
97
98 pub fn set_listen_endpoint(&self, resolved_endpoint: &Endpoint) -> Result<()> {
100 let resolved_endpoint = Endpoint::Udp(
101 resolved_endpoint.addr()?.clone(),
102 self.config.discovery_port,
103 );
104 *self.listen_endpoint.write() = Some(resolved_endpoint);
105 Ok(())
106 }
107
    /// Shuts down the service by cancelling its task group and awaiting the
    /// cancellation.
    pub async fn shutdown(&self) {
        self.task_group.cancel().await;
    }
112
113 async fn refresh_loop(self: Arc<Self>) -> Result<()> {
117 loop {
118 sleep(Duration::from_secs(self.config.refresh_interval)).await;
119 trace!("Start refreshing the routing table...");
120
121 self.monitor.notify(DiscvEvent::RefreshStarted).await;
122
123 let mut entries: Vec<BucketEntry> = vec![];
124 for bucket in self.table.buckets() {
125 for entry in bucket
126 .iter()
127 .filter(|e| !e.is_connected() && !e.is_incompatible())
128 .take(8)
129 {
130 entries.push(entry.clone())
131 }
132 }
133
134 self.clone().do_refresh(&entries).await;
135 }
136 }
137
138 async fn do_refresh(self: Arc<Self>, entries: &[BucketEntry]) {
140 use futures_util::stream::{FuturesUnordered, StreamExt};
141 for chunk in entries.chunks(16) {
143 let mut tasks = FuturesUnordered::new();
144 for bucket_entry in chunk {
145 if bucket_entry.failures >= MAX_FAILURES {
146 self.table.remove_entry(&bucket_entry.entry.key);
147 continue;
148 }
149
150 tasks.push(self.clone().refresh_entry(bucket_entry.clone()))
151 }
152
153 while tasks.next().await.is_some() {}
154 }
155 }
156
157 async fn refresh_entry(self: Arc<Self>, bucket_entry: BucketEntry) {
160 let key = &bucket_entry.entry.key;
161 match self.connect(&bucket_entry.entry).await {
162 Ok(_) => {
163 self.table.update_entry(key, PENDING_ENTRY);
164 }
165 Err(err) => {
166 trace!("Failed to refresh entry {:?}: {err}", key);
167 if bucket_entry.failures >= MAX_FAILURES {
168 self.table.remove_entry(key);
169 return;
170 }
171 self.table.update_entry(key, UNREACHABLE_ENTRY);
172 }
173 }
174 }
175
176 async fn connect(&self, entry: &Entry) -> Result<()> {
180 let mut retry = 0;
181 let endpoint = Endpoint::Udp(entry.addr.clone(), entry.discovery_port);
182 let conn = udp::dial(&endpoint, Default::default(), RefreshMsgCodec {}).await?;
183 let backoff = Backoff::new(100, 5000);
184 while retry < self.config.refresh_connect_retries {
185 match self.send_ping_msg(&conn, &endpoint).await {
186 Ok(()) => return Ok(()),
187 Err(Error::Timeout) => {
188 retry += 1;
189 backoff.sleep().await;
190 }
191 Err(err) => {
192 return Err(err);
193 }
194 }
195 }
196
197 Err(Error::Timeout)
198 }
199
200 async fn listen_loop(self: Arc<Self>, endpoint: Endpoint) -> Result<()> {
203 let conn = match udp::listen(&endpoint, Default::default(), RefreshMsgCodec {}).await {
204 Ok(c) => {
205 self.monitor
206 .notify(ConnEvent::Listening(endpoint.clone()))
207 .await;
208 c
209 }
210 Err(err) => {
211 self.monitor
212 .notify(ConnEvent::ListenFailed(endpoint.clone()))
213 .await;
214 return Err(err.into());
215 }
216 };
217 info!("Start listening on {endpoint}");
218
219 loop {
220 let res = self.listen_to_ping_msg(&conn).await;
221 if let Err(err) = res {
222 trace!("Failed to handle ping msg {err}");
223 self.monitor.notify(ConnEvent::AcceptFailed).await;
224 }
225 }
226 }
227
228 async fn listen_to_ping_msg(&self, conn: &udp::UdpConn<RefreshMsgCodec>) -> Result<()> {
230 let (msg, endpoint) = conn.recv().await?;
231 self.monitor
232 .notify(ConnEvent::Accepted(endpoint.clone()))
233 .await;
234
235 match msg {
236 RefreshMsg::Ping(m) => {
237 let pong_msg = RefreshMsg::Pong(m);
238 conn.send((pong_msg, endpoint.clone())).await?;
239 }
240 RefreshMsg::Pong(_) => return Err(Error::InvalidMsg("Unexpected pong msg".into())),
241 }
242
243 self.monitor.notify(ConnEvent::Disconnected(endpoint)).await;
244 Ok(())
245 }
246
247 async fn send_ping_msg(
249 &self,
250 conn: &udp::UdpConn<RefreshMsgCodec>,
251 endpoint: &Endpoint,
252 ) -> Result<()> {
253 let mut nonce: [u8; 32] = [0; 32];
254 RngCore::fill_bytes(&mut OsRng, &mut nonce);
255 conn.send((RefreshMsg::Ping(nonce), endpoint.clone()))
256 .await?;
257
258 let t = Duration::from_secs(self.config.refresh_response_timeout);
259 let (msg, _) = timeout(t, conn.recv()).await??;
260
261 match msg {
262 RefreshMsg::Pong(n) => {
263 if n != nonce {
264 return Err(Error::InvalidPongMsg);
265 }
266 Ok(())
267 }
268 _ => Err(Error::InvalidMsg("Unexpected ping msg".into())),
269 }
270 }
271}