karyon_p2p/
peer_pool.rs

1use std::{collections::HashMap, sync::Arc};
2
3use bincode::Encode;
4use log::{error, info, warn};
5
6use karyon_core::{
7    async_runtime::{lock::RwLock, Executor},
8    async_util::{TaskGroup, TaskResult},
9};
10
11use karyon_net::Endpoint;
12
13use crate::{
14    config::Config,
15    conn_queue::ConnQueue,
16    connection::Connection,
17    monitor::{Monitor, PPEvent},
18    peer::Peer,
19    protocol::{Protocol, ProtocolConstructor, ProtocolID},
20    protocols::PingProtocol,
21    version::Version,
22    Error, PeerID, Result,
23};
24
25pub struct PeerPool {
26    /// Peer's ID
27    pub id: PeerID,
28
29    /// Connection queue
30    conn_queue: Arc<ConnQueue>,
31
32    /// Holds the running peers.
33    peers: RwLock<HashMap<PeerID, Arc<Peer>>>,
34
35    /// Hashmap contains protocol constructors.
36    pub(crate) protocols: RwLock<HashMap<ProtocolID, Box<ProtocolConstructor>>>,
37
38    /// Hashmap contains protocols with their versions
39    pub(crate) protocol_versions: RwLock<HashMap<ProtocolID, Version>>,
40
41    /// Managing spawned tasks.
42    task_group: TaskGroup,
43
44    /// A global Executor
45    executor: Executor,
46
47    /// The Configuration for the P2P network.
48    config: Arc<Config>,
49
50    /// Responsible for network and system monitoring.
51    monitor: Arc<Monitor>,
52}
53
54impl PeerPool {
55    /// Creates a new PeerPool
56    pub fn new(
57        id: &PeerID,
58        conn_queue: Arc<ConnQueue>,
59        config: Arc<Config>,
60        monitor: Arc<Monitor>,
61        executor: Executor,
62    ) -> Arc<Self> {
63        Arc::new(Self {
64            id: id.clone(),
65            conn_queue,
66            peers: RwLock::new(HashMap::new()),
67            protocols: RwLock::new(HashMap::new()),
68            protocol_versions: RwLock::new(HashMap::new()),
69            task_group: TaskGroup::with_executor(executor.clone()),
70            executor,
71            monitor,
72            config,
73        })
74    }
75
76    /// Starts the [`PeerPool`]
77    pub async fn start(self: &Arc<Self>) -> Result<()> {
78        self.setup_core_protocols().await?;
79        self.task_group.spawn(self.clone().run(), |_| async {});
80        Ok(())
81    }
82
83    /// Shuts down
84    pub async fn shutdown(&self) {
85        for (_, peer) in self.peers.read().await.iter() {
86            let _ = peer.shutdown().await;
87        }
88
89        self.task_group.cancel().await;
90    }
91
92    /// Attach a custom protocol to the network
93    pub async fn attach_protocol<P: Protocol>(&self, c: Box<ProtocolConstructor>) -> Result<()> {
94        self.protocols.write().await.insert(P::id(), c);
95        self.protocol_versions
96            .write()
97            .await
98            .insert(P::id(), P::version()?);
99        Ok(())
100    }
101
102    /// Broadcast a message to all connected peers using the specified protocol.
103    pub async fn broadcast<T: Encode>(&self, proto_id: &ProtocolID, msg: &T) {
104        for (pid, peer) in self.peers.read().await.iter() {
105            if let Err(err) = peer.conn.send(proto_id.to_string(), msg).await {
106                error!("failed to send msg to {pid}: {err}");
107                continue;
108            }
109        }
110    }
111
112    /// Checks if the peer list contains a peer with the given peer id
113    pub async fn contains_peer(&self, pid: &PeerID) -> bool {
114        self.peers.read().await.contains_key(pid)
115    }
116
117    /// Returns the number of currently connected peers.
118    pub async fn peers_len(&self) -> usize {
119        self.peers.read().await.len()
120    }
121
122    /// Returns a map of inbound peers with their endpoints.
123    pub async fn inbound_peers(&self) -> HashMap<PeerID, Endpoint> {
124        let mut peers = HashMap::new();
125        for (id, peer) in self.peers.read().await.iter() {
126            if peer.is_inbound() {
127                peers.insert(id.clone(), peer.remote_endpoint().clone());
128            }
129        }
130        peers
131    }
132
133    /// Returns a map of outbound peers with their endpoints.
134    pub async fn outbound_peers(&self) -> HashMap<PeerID, Endpoint> {
135        let mut peers = HashMap::new();
136        for (id, peer) in self.peers.read().await.iter() {
137            if !peer.is_inbound() {
138                peers.insert(id.clone(), peer.remote_endpoint().clone());
139            }
140        }
141        peers
142    }
143
144    async fn run(self: Arc<Self>) {
145        loop {
146            let mut conn = self.conn_queue.next().await;
147
148            for protocol_id in self.protocols.read().await.keys() {
149                conn.register_protocol(protocol_id.to_string()).await;
150            }
151
152            let conn = Arc::new(conn);
153
154            let result = self.new_peer(conn.clone()).await;
155
156            // Disconnect if there is an error when adding a peer.
157            if result.is_err() {
158                let _ = conn.disconnect(result).await;
159            }
160        }
161    }
162
163    /// Add a new peer to the peer list.
164    async fn new_peer(self: &Arc<Self>, conn: Arc<Connection>) -> Result<()> {
165        // Create a new peer
166        let peer = Peer::new(
167            self.id.clone(),
168            Arc::downgrade(self),
169            conn.clone(),
170            self.config.clone(),
171            self.executor.clone(),
172        );
173        peer.init().await?;
174        let pid = peer.id().expect("Get peer id after peer initialization");
175
176        // TODO: Consider restricting the subnet for inbound connections
177        if self.contains_peer(&pid).await {
178            return Err(Error::PeerAlreadyConnected);
179        }
180
181        // Insert the new peer
182        self.peers.write().await.insert(pid.clone(), peer.clone());
183
184        let on_disconnect = {
185            let this = self.clone();
186            let pid = pid.clone();
187            |result| async move {
188                if let TaskResult::Completed(_) = result {
189                    if let Err(err) = this.remove_peer(&pid).await {
190                        error!("Failed to remove peer {pid}: {err}");
191                    }
192                }
193            }
194        };
195
196        self.task_group.spawn(peer.run(), on_disconnect);
197
198        info!("Add new peer {pid}");
199        self.monitor.notify(PPEvent::NewPeer(pid)).await;
200
201        Ok(())
202    }
203
204    /// Shuts down the peer and remove it from the peer list.
205    async fn remove_peer(&self, pid: &PeerID) -> Result<()> {
206        let result = self.peers.write().await.remove(pid);
207
208        let peer = match result {
209            Some(p) => p,
210            None => return Ok(()),
211        };
212
213        let _ = peer.shutdown().await;
214
215        self.monitor.notify(PPEvent::RemovePeer(pid.clone())).await;
216
217        warn!("Peer {pid} removed",);
218        Ok(())
219    }
220
221    /// Attach the core protocols.
222    async fn setup_core_protocols(&self) -> Result<()> {
223        let executor = self.executor.clone();
224        let ping_interval = self.config.ping_interval;
225        let ping_timeout = self.config.ping_timeout;
226        let c = move |peer| PingProtocol::new(peer, ping_interval, ping_timeout, executor.clone());
227        self.attach_protocol::<PingProtocol>(Box::new(c)).await
228    }
229}