// karyon_p2p/discovery/refresh.rs
use std::{sync::Arc, time::Duration};
use bincode::{Decode, Encode};
use log::{error, info, trace};
use parking_lot::RwLock;
use rand::{rngs::OsRng, RngCore};
use karyon_core::{
async_runtime::Executor,
async_util::{sleep, timeout, Backoff, TaskGroup, TaskResult},
};
use karyon_net::{udp, Connection, Endpoint};
use crate::{
codec::RefreshMsgCodec,
message::RefreshMsg,
monitor::{ConnEvent, DiscvEvent, Monitor},
routing_table::{BucketEntry, Entry, RoutingTable, PENDING_ENTRY, UNREACHABLE_ENTRY},
Config, Error, Result,
};
/// Failure count at (or above) which a routing-table entry is removed
/// instead of being refreshed again.
pub const MAX_FAILURES: u32 = 3;
/// Wrapper around a 32-byte array that can be encoded/decoded with bincode.
///
/// NOTE(review): not referenced in the visible part of this file; presumably
/// the payload of a ping (a nonce) — confirm against `RefreshMsg`.
#[derive(Decode, Encode, Debug, Clone)]
pub struct PingMsg(pub [u8; 32]);
/// Wrapper around a 32-byte array that can be encoded/decoded with bincode.
///
/// NOTE(review): not referenced in the visible part of this file; presumably
/// the payload of a pong (the echoed nonce) — confirm against `RefreshMsg`.
#[derive(Decode, Encode, Debug)]
pub struct PongMsg(pub [u8; 32]);
/// Service that periodically pings routing-table entries over UDP and
/// answers pings received on its own listen endpoint.
pub struct RefreshService {
/// Routing table whose entries are pinged, updated, or removed.
table: Arc<RoutingTable>,
/// UDP endpoint to listen on for incoming pings; `None` until
/// `set_listen_endpoint` is called.
listen_endpoint: RwLock<Option<Endpoint>>,
/// Task group on which the listen and refresh loops are spawned.
task_group: TaskGroup,
/// Shared configuration (discovery port, refresh interval, response
/// timeout, and connect retries are read from it).
config: Arc<Config>,
/// Receives the discovery and connection events emitted by this service.
monitor: Arc<Monitor>,
}
impl RefreshService {
/// Creates a new refresh service.
///
/// The service starts without a listen endpoint; call
/// `set_listen_endpoint` before `start` if it should answer incoming pings.
///
/// * `config` - shared configuration read by the refresh and listen loops.
/// * `table` - routing table whose entries will be refreshed.
/// * `monitor` - receiver of the events emitted by the service.
/// * `executor` - executor handed to the internal task group.
pub fn new(
    config: Arc<Config>,
    table: Arc<RoutingTable>,
    monitor: Arc<Monitor>,
    executor: Executor,
) -> Self {
    Self {
        table,
        listen_endpoint: RwLock::new(None),
        // `executor` is owned and not used again, so move it instead of cloning.
        task_group: TaskGroup::with_executor(executor),
        config,
        monitor,
    }
}
pub async fn start(self: &Arc<Self>) -> Result<()> {
if let Some(endpoint) = self.listen_endpoint.read().as_ref() {
let endpoint = endpoint.clone();
self.task_group.spawn(
{
let this = self.clone();
async move { this.listen_loop(endpoint).await }
},
|res| async move {
if let TaskResult::Completed(Err(err)) = res {
error!("Listen loop stopped: {err}");
}
},
);
}
self.task_group.spawn(
{
let this = self.clone();
async move { this.refresh_loop().await }
},
|res| async move {
if let TaskResult::Completed(Err(err)) = res {
error!("Refresh loop stopped: {err}");
}
},
);
Ok(())
}
/// Stores the UDP endpoint the service should listen on.
///
/// The endpoint is built from the address of `resolved_endpoint` combined
/// with the configured `discovery_port`.
///
/// # Errors
///
/// Propagates the error returned by `resolved_endpoint.addr()`.
pub fn set_listen_endpoint(&self, resolved_endpoint: &Endpoint) -> Result<()> {
    let addr = resolved_endpoint.addr()?.clone();
    let discovery_endpoint = Endpoint::Udp(addr, self.config.discovery_port);

    let mut slot = self.listen_endpoint.write();
    *slot = Some(discovery_endpoint);
    Ok(())
}
/// Cancels the service's task group and awaits the cancellation call.
pub async fn shutdown(&self) {
    let tasks = &self.task_group;
    tasks.cancel().await;
}
async fn refresh_loop(self: Arc<Self>) -> Result<()> {
loop {
sleep(Duration::from_secs(self.config.refresh_interval)).await;
trace!("Start refreshing the routing table...");
self.monitor.notify(DiscvEvent::RefreshStarted).await;
let mut entries: Vec<BucketEntry> = vec![];
for bucket in self.table.buckets() {
for entry in bucket
.iter()
.filter(|e| !e.is_connected() && !e.is_incompatible())
.take(8)
{
entries.push(entry.clone())
}
}
self.clone().do_refresh(&entries).await;
}
}
async fn do_refresh(self: Arc<Self>, entries: &[BucketEntry]) {
use futures_util::stream::{FuturesUnordered, StreamExt};
for chunk in entries.chunks(16) {
let mut tasks = FuturesUnordered::new();
for bucket_entry in chunk {
if bucket_entry.failures >= MAX_FAILURES {
self.table.remove_entry(&bucket_entry.entry.key);
continue;
}
tasks.push(self.clone().refresh_entry(bucket_entry.clone()))
}
while tasks.next().await.is_some() {}
}
}
/// Pings a single entry and records the outcome in the routing table.
///
/// On success the entry is updated with `PENDING_ENTRY`. On failure the
/// error is traced and the entry is either removed (failure count at or
/// above `MAX_FAILURES`) or updated with `UNREACHABLE_ENTRY`.
///
/// NOTE(review): `do_refresh` already skips entries whose failure count has
/// reached `MAX_FAILURES`, and `bucket_entry` is a snapshot, so the removal
/// branch looks unreachable from that caller — confirm before relying on it.
async fn refresh_entry(self: Arc<Self>, bucket_entry: BucketEntry) {
    let key = &bucket_entry.entry.key;
    let outcome = self.connect(&bucket_entry.entry).await;

    if let Err(err) = outcome {
        trace!("Failed to refresh entry {:?}: {err}", key);
        if bucket_entry.failures >= MAX_FAILURES {
            self.table.remove_entry(key);
        } else {
            self.table.update_entry(key, UNREACHABLE_ENTRY);
        }
        return;
    }

    self.table.update_entry(key, PENDING_ENTRY);
}
/// Dials the entry's UDP discovery endpoint and pings it until a valid pong
/// arrives or the retry budget is exhausted.
///
/// At most `config.refresh_connect_retries` pings are sent. Only
/// `Error::Timeout` triggers another attempt; between attempts the task
/// sleeps on a `Backoff::new(100, 5000)` (presumably milliseconds — confirm
/// against `Backoff`).
///
/// # Errors
///
/// * the error from `udp::dial` if the endpoint cannot be dialed;
/// * any non-timeout error from `send_ping_msg`, returned immediately;
/// * `Error::Timeout` once every attempt has timed out (also when the retry
///   budget is zero, in which case no ping is sent).
async fn connect(&self, entry: &Entry) -> Result<()> {
    let endpoint = Endpoint::Udp(entry.addr.clone(), entry.discovery_port);
    let conn = udp::dial(&endpoint, Default::default(), RefreshMsgCodec {}).await?;
    let backoff = Backoff::new(100, 5000);

    let max_attempts = self.config.refresh_connect_retries;
    let mut attempt = 0;
    while attempt < max_attempts {
        match self.send_ping_msg(&conn, &endpoint).await {
            Ok(()) => return Ok(()),
            Err(Error::Timeout) => {
                attempt += 1;
                // Back off only when another attempt follows; sleeping after
                // the final timeout would just delay reporting the failure.
                if attempt < max_attempts {
                    backoff.sleep().await;
                }
            }
            Err(err) => return Err(err),
        }
    }

    Err(Error::Timeout)
}
/// Listens on `endpoint` over UDP and answers incoming pings; once the
/// listener is up, the loop has no exit of its own.
///
/// The monitor is notified with `ConnEvent::Listening` on a successful bind
/// and with `ConnEvent::ListenFailed` otherwise. Every failure while handling
/// a message is traced and reported as `ConnEvent::AcceptFailed`; the loop
/// then continues with the next message.
///
/// # Errors
///
/// Returns the (converted) error from `udp::listen` if listening fails.
async fn listen_loop(self: Arc<Self>, endpoint: Endpoint) -> Result<()> {
    let listener = udp::listen(&endpoint, Default::default(), RefreshMsgCodec {}).await;

    let conn = match listener {
        Ok(conn) => conn,
        Err(err) => {
            let event = ConnEvent::ListenFailed(endpoint.clone());
            self.monitor.notify(event).await;
            return Err(err.into());
        }
    };

    let event = ConnEvent::Listening(endpoint.clone());
    self.monitor.notify(event).await;
    info!("Start listening on {endpoint}");

    loop {
        if let Err(err) = self.listen_to_ping_msg(&conn).await {
            trace!("Failed to handle ping msg {err}");
            self.monitor.notify(ConnEvent::AcceptFailed).await;
        }
    }
}
/// Receives one message on `conn` and, if it is a ping, echoes its payload
/// back to the sender as a pong.
///
/// The monitor is notified with `ConnEvent::Accepted` as soon as a message
/// arrives and with `ConnEvent::Disconnected` after the pong has been sent.
/// The latter is not emitted on the error paths.
///
/// # Errors
///
/// * the error from `conn.recv()` or `conn.send()`;
/// * `Error::InvalidMsg` when the received message is a pong.
async fn listen_to_ping_msg(&self, conn: &udp::UdpConn<RefreshMsgCodec>) -> Result<()> {
    let (msg, peer) = conn.recv().await?;
    self.monitor.notify(ConnEvent::Accepted(peer.clone())).await;

    let nonce = match msg {
        RefreshMsg::Ping(nonce) => nonce,
        RefreshMsg::Pong(_) => return Err(Error::InvalidMsg("Unexpected pong msg".into())),
    };

    let reply = RefreshMsg::Pong(nonce);
    conn.send((reply, peer.clone())).await?;

    self.monitor.notify(ConnEvent::Disconnected(peer)).await;
    Ok(())
}
async fn send_ping_msg(
&self,
conn: &udp::UdpConn<RefreshMsgCodec>,
endpoint: &Endpoint,
) -> Result<()> {
let mut nonce: [u8; 32] = [0; 32];
RngCore::fill_bytes(&mut OsRng, &mut nonce);
conn.send((RefreshMsg::Ping(nonce), endpoint.clone()))
.await?;
let t = Duration::from_secs(self.config.refresh_response_timeout);
let (msg, _) = timeout(t, conn.recv()).await??;
match msg {
RefreshMsg::Pong(n) => {
if n != nonce {
return Err(Error::InvalidPongMsg);
}
Ok(())
}
_ => Err(Error::InvalidMsg("Unexpected ping msg".into())),
}
}
}