Skip to main content

karyon_p2p/
peer_pool.rs

1use std::{
2    collections::{HashMap, HashSet},
3    sync::Arc,
4};
5
6use log::{error, info, warn};
7
8use karyon_core::{
9    async_runtime::{lock::RwLock, Executor},
10    async_util::{TaskGroup, TaskResult},
11};
12
13use karyon_eventemitter::{EventEmitter, EventListener, EventTopic, EventValue};
14
15use karyon_net::Endpoint;
16
17use crate::{
18    config::Config,
19    conn_queue::{ConnQueue, QueuedConn},
20    handshake::{handshake, HandshakeParams},
21    monitor::{Monitor, PoolEvent},
22    peer::Peer,
23    protocol::{Protocol, ProtocolConstructor, ProtocolID, ProtocolMeta},
24    Error, PeerID, Result,
25};
26
/// Topic under which the peer pool publishes [`PeerEvent`]s.
///
/// There is a single topic that covers the whole peer lifecycle.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum PeerEventTopic {
    /// Peer added / removed / handshake-failed notifications.
    Lifecycle,
}
32
33/// Peer-lifecycle events. Each registered listener receives every
34/// event independently.
35#[derive(Debug, Clone, EventValue)]
36pub enum PeerEvent {
37    /// Peer added after a successful handshake.
38    Added(PeerID),
39    /// Previously-added peer removed.
40    Removed(PeerID),
41    /// Handshake failed before the peer was added.
42    HandshakeFailed(Option<PeerID>),
43}
44
45impl EventTopic for PeerEvent {
46    type Topic = PeerEventTopic;
47    fn topic() -> Self::Topic {
48        PeerEventTopic::Lifecycle
49    }
50}
51
/// Pool of connected peers.
///
/// Pulls connections from the connection queue, runs the handshake on each
/// one, and keeps the peers that pass it, with one spawned task per peer. It
/// also holds the registry of attached protocols and publishes
/// peer-lifecycle events ([`PeerEvent`]) to registered listeners.
pub struct PeerPool {
    /// This node's own peer ID (sent as `own_id` in the handshake parameters).
    pub id: PeerID,

    /// Queue of pending connections. The pool's run loop awaits `next()` on
    /// it and handshakes each connection it yields.
    conn_queue: Arc<ConnQueue>,

    /// Peers that completed the handshake, keyed by peer ID. Entries are
    /// inserted by `new_peer` and removed by `remove_peer`.
    peers: RwLock<HashMap<PeerID, Arc<Peer>>>,

    /// Constructors of the attached protocols, keyed by protocol ID. The key
    /// set is passed to every newly created `Peer`.
    pub(crate) protocols: RwLock<HashMap<ProtocolID, Box<ProtocolConstructor>>>,

    /// Per-protocol metadata (version + kind, extensible). Keyed by
    /// protocol id. Source of truth for the handshake's mandatory check
    /// and version negotiation.
    pub(crate) protocol_meta: RwLock<HashMap<ProtocolID, ProtocolMeta>>,

    /// Peer-lifecycle event emitter (topic `PeerEventTopic::Lifecycle`).
    /// Listeners are created with `register_peer_events`, and each one gets
    /// its own copy of every event.
    peer_emitter: Arc<EventEmitter<PeerEventTopic>>,

    /// Owns the tasks the pool spawns: the connection loop and one run task
    /// per peer. Cancelled by `shutdown`.
    task_group: TaskGroup,

    /// Executor; a clone of it backs `task_group`.
    pub(crate) executor: Executor,

    /// The configuration for the P2P network. The run loop reads `version`
    /// and `handshake_timeout` from it when building handshake parameters.
    pub(crate) config: Arc<Config>,

    /// Receives `PoolEvent` notifications: new/removed peers, failed
    /// handshakes and duplicate connections.
    monitor: Arc<Monitor>,
}
86
87impl PeerPool {
88    /// Creates a new PeerPool
89    pub fn new(
90        id: &PeerID,
91        conn_queue: Arc<ConnQueue>,
92        config: Arc<Config>,
93        monitor: Arc<Monitor>,
94        executor: Executor,
95    ) -> Arc<Self> {
96        Arc::new(Self {
97            id: id.clone(),
98            conn_queue,
99            peers: RwLock::new(HashMap::new()),
100            protocols: RwLock::new(HashMap::new()),
101            protocol_meta: RwLock::new(HashMap::new()),
102            peer_emitter: EventEmitter::new(),
103            task_group: TaskGroup::with_executor(executor.clone()),
104            executor,
105            monitor,
106            config,
107        })
108    }
109
110    /// Register a listener for the peer-lifecycle events.
111    pub fn register_peer_events(&self) -> EventListener<PeerEventTopic, PeerEvent> {
112        self.peer_emitter.register(&PeerEventTopic::Lifecycle)
113    }
114
115    /// Starts the [`PeerPool`]
116    pub async fn start(self: &Arc<Self>) -> Result<()> {
117        self.task_group.spawn(self.clone().run(), |_| async {});
118        Ok(())
119    }
120
121    /// Shuts down
122    pub async fn shutdown(&self) {
123        for (_, peer) in self.peers.read().await.iter() {
124            let _ = peer.shutdown().await;
125        }
126
127        self.task_group.cancel().await;
128    }
129
130    /// Register a protocol's user-supplied constructor and metadata.
131    /// Bloom advertising is handled by `Node::attach_protocol`.
132    pub async fn attach_protocol<P: Protocol>(&self, c: Box<ProtocolConstructor>) -> Result<()> {
133        let id = P::id();
134        self.protocols.write().await.insert(id.clone(), c);
135        self.protocol_meta.write().await.insert(
136            id,
137            ProtocolMeta {
138                version: P::version()?,
139                kind: P::kind(),
140            },
141        );
142        Ok(())
143    }
144
145    /// Broadcast a message to all connected peers.
146    pub async fn broadcast(&self, proto_id: &ProtocolID, msg: Vec<u8>) {
147        for (pid, peer) in self.peers.read().await.iter() {
148            if let Err(err) = peer.send(proto_id.to_string(), msg.clone()).await {
149                error!("failed to send msg to {pid}: {err}");
150                continue;
151            }
152        }
153    }
154
155    /// Broadcast a message to a specific set of peers.
156    pub async fn broadcast_to(
157        &self,
158        proto_id: &ProtocolID,
159        msg: Vec<u8>,
160        targets: &HashSet<PeerID>,
161    ) {
162        for (pid, peer) in self.peers.read().await.iter() {
163            if !targets.contains(pid) {
164                continue;
165            }
166            if let Err(err) = peer.send(proto_id.to_string(), msg.clone()).await {
167                error!("failed to send msg to {pid}: {err}");
168            }
169        }
170    }
171
172    /// Send a message to a specific peer on the given protocol. Returns
173    /// `PeerNotFound` if the peer is not currently in the pool.
174    pub async fn send_to(
175        &self,
176        peer_id: &PeerID,
177        proto_id: &ProtocolID,
178        msg: Vec<u8>,
179    ) -> Result<()> {
180        let peers = self.peers.read().await;
181        let peer = peers
182            .get(peer_id)
183            .ok_or_else(|| Error::PeerNotFound(peer_id.to_string()))?;
184        peer.send(proto_id.to_string(), msg).await
185    }
186
187    /// Returns the negotiated protocol set for a peer.
188    pub async fn peer_protocol_set(&self, pid: &PeerID) -> Option<HashSet<ProtocolID>> {
189        self.peers
190            .read()
191            .await
192            .get(pid)
193            .map(|p| p.negotiated_protocols().clone())
194    }
195
196    /// Checks if the peer list contains a peer with the given peer id
197    pub async fn contains_peer(&self, pid: &PeerID) -> bool {
198        self.peers.read().await.contains_key(pid)
199    }
200
201    /// Returns the number of currently connected peers.
202    pub async fn peers_len(&self) -> usize {
203        self.peers.read().await.len()
204    }
205
206    /// Returns a map of inbound peers with their endpoints.
207    pub async fn inbound_peers(&self) -> HashMap<PeerID, Endpoint> {
208        let mut peers = HashMap::new();
209        for (id, peer) in self.peers.read().await.iter() {
210            if peer.is_inbound() {
211                peers.insert(id.clone(), peer.remote_endpoint().clone());
212            }
213        }
214        peers
215    }
216
217    /// Returns a map of outbound peers with their endpoints.
218    pub async fn outbound_peers(&self) -> HashMap<PeerID, Endpoint> {
219        let mut peers = HashMap::new();
220        for (id, peer) in self.peers.read().await.iter() {
221            if !peer.is_inbound() {
222                peers.insert(id.clone(), peer.remote_endpoint().clone());
223            }
224        }
225        peers
226    }
227
228    async fn run(self: Arc<Self>) {
229        loop {
230            let mut queued = self.conn_queue.next().await;
231
232            // Snapshot the protocol metadata so we don't hold a lock
233            // across the handshake. Drives both version negotiation
234            // and the mandatory-subset check.
235            let meta = self.protocol_meta.read().await.clone();
236
237            let params = HandshakeParams {
238                own_id: &self.id,
239                is_inbound: matches!(queued.direction, crate::peer::ConnDirection::Inbound),
240                config_version: &self.config.version,
241                protocols: &meta,
242                timeout_secs: self.config.handshake_timeout,
243                verified_peer_id: queued.verified_peer_id.as_ref(),
244            };
245            let handshake = handshake(&mut queued.reader, &mut queued.writer, &params).await;
246
247            let (pid, negotiated) = match handshake {
248                Ok(v) => v,
249                Err(err) => {
250                    let pid = queued.verified_peer_id.clone();
251                    self.monitor
252                        .notify(PoolEvent::HandshakeFailed(pid.clone()))
253                        .await;
254                    let _ = self
255                        .peer_emitter
256                        .emit(&PeerEvent::HandshakeFailed(pid))
257                        .await;
258                    let _ = queued.disconnect_signal.send(Err(err)).await;
259                    continue;
260                }
261            };
262
263            if let Err(err) = self.new_peer(queued, pid, negotiated).await {
264                error!("new_peer failed: {err}");
265            }
266        }
267    }
268
269    /// Build a Peer from a post-handshake `QueuedConn` and run it.
270    async fn new_peer(
271        self: &Arc<Self>,
272        queued: QueuedConn,
273        pid: PeerID,
274        negotiated: Vec<ProtocolID>,
275    ) -> Result<()> {
276        if self.contains_peer(&pid).await {
277            self.monitor
278                .notify(PoolEvent::PeerAlreadyConnected(pid.clone()))
279                .await;
280            let _ = queued
281                .disconnect_signal
282                .send(Err(Error::PeerAlreadyConnected))
283                .await;
284            return Err(Error::PeerAlreadyConnected);
285        }
286
287        let protocol_ids: Vec<ProtocolID> = self.protocols.read().await.keys().cloned().collect();
288        let negotiated: HashSet<ProtocolID> = negotiated.into_iter().collect();
289
290        let peer = Peer::new(self.clone(), queued, pid.clone(), negotiated, protocol_ids).await?;
291
292        self.peers.write().await.insert(pid.clone(), peer.clone());
293
294        let on_disconnect = {
295            let this = self.clone();
296            let pid = pid.clone();
297            |result| async move {
298                if let TaskResult::Completed(_) = result {
299                    if let Err(err) = this.remove_peer(&pid).await {
300                        error!("Failed to remove peer {pid}: {err}");
301                    }
302                }
303            }
304        };
305
306        self.task_group.spawn(peer.clone().run(), on_disconnect);
307
308        info!("Add new peer {pid}");
309        self.monitor.notify(PoolEvent::NewPeer(pid.clone())).await;
310        let _ = self.peer_emitter.emit(&PeerEvent::Added(pid)).await;
311
312        Ok(())
313    }
314
315    /// Shuts down the peer and remove it from the peer list.
316    async fn remove_peer(&self, pid: &PeerID) -> Result<()> {
317        let result = self.peers.write().await.remove(pid);
318
319        let peer = match result {
320            Some(p) => p,
321            None => return Ok(()),
322        };
323
324        let _ = peer.shutdown().await;
325
326        self.monitor
327            .notify(PoolEvent::RemovePeer(pid.clone()))
328            .await;
329        let _ = self
330            .peer_emitter
331            .emit(&PeerEvent::Removed(pid.clone()))
332            .await;
333
334        warn!("Peer {pid} removed",);
335        Ok(())
336    }
337}