1use std::{
2 collections::{HashMap, HashSet},
3 sync::Arc,
4};
5
6use log::{error, info, warn};
7
8use karyon_core::{
9 async_runtime::{lock::RwLock, Executor},
10 async_util::{TaskGroup, TaskResult},
11};
12
13use karyon_eventemitter::{EventEmitter, EventListener, EventTopic, EventValue};
14
15use karyon_net::Endpoint;
16
17use crate::{
18 config::Config,
19 conn_queue::{ConnQueue, QueuedConn},
20 handshake::{handshake, HandshakeParams},
21 monitor::{Monitor, PoolEvent},
22 peer::Peer,
23 protocol::{Protocol, ProtocolConstructor, ProtocolID, ProtocolMeta},
24 Error, PeerID, Result,
25};
26
/// Topics on which peer-pool events are published.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum PeerEventTopic {
    /// Peer lifecycle changes: peers being added, removed, or failing the
    /// handshake.
    Lifecycle,
}
32
33#[derive(Debug, Clone, EventValue)]
36pub enum PeerEvent {
37 Added(PeerID),
39 Removed(PeerID),
41 HandshakeFailed(Option<PeerID>),
43}
44
45impl EventTopic for PeerEvent {
46 type Topic = PeerEventTopic;
47 fn topic() -> Self::Topic {
48 PeerEventTopic::Lifecycle
49 }
50}
51
52pub struct PeerPool {
53 pub id: PeerID,
55
56 conn_queue: Arc<ConnQueue>,
58
59 peers: RwLock<HashMap<PeerID, Arc<Peer>>>,
61
62 pub(crate) protocols: RwLock<HashMap<ProtocolID, Box<ProtocolConstructor>>>,
64
65 pub(crate) protocol_meta: RwLock<HashMap<ProtocolID, ProtocolMeta>>,
69
70 peer_emitter: Arc<EventEmitter<PeerEventTopic>>,
73
74 task_group: TaskGroup,
76
77 pub(crate) executor: Executor,
79
80 pub(crate) config: Arc<Config>,
82
83 monitor: Arc<Monitor>,
85}
86
87impl PeerPool {
88 pub fn new(
90 id: &PeerID,
91 conn_queue: Arc<ConnQueue>,
92 config: Arc<Config>,
93 monitor: Arc<Monitor>,
94 executor: Executor,
95 ) -> Arc<Self> {
96 Arc::new(Self {
97 id: id.clone(),
98 conn_queue,
99 peers: RwLock::new(HashMap::new()),
100 protocols: RwLock::new(HashMap::new()),
101 protocol_meta: RwLock::new(HashMap::new()),
102 peer_emitter: EventEmitter::new(),
103 task_group: TaskGroup::with_executor(executor.clone()),
104 executor,
105 monitor,
106 config,
107 })
108 }
109
110 pub fn register_peer_events(&self) -> EventListener<PeerEventTopic, PeerEvent> {
112 self.peer_emitter.register(&PeerEventTopic::Lifecycle)
113 }
114
115 pub async fn start(self: &Arc<Self>) -> Result<()> {
117 self.task_group.spawn(self.clone().run(), |_| async {});
118 Ok(())
119 }
120
121 pub async fn shutdown(&self) {
123 for (_, peer) in self.peers.read().await.iter() {
124 let _ = peer.shutdown().await;
125 }
126
127 self.task_group.cancel().await;
128 }
129
130 pub async fn attach_protocol<P: Protocol>(&self, c: Box<ProtocolConstructor>) -> Result<()> {
133 let id = P::id();
134 self.protocols.write().await.insert(id.clone(), c);
135 self.protocol_meta.write().await.insert(
136 id,
137 ProtocolMeta {
138 version: P::version()?,
139 kind: P::kind(),
140 },
141 );
142 Ok(())
143 }
144
145 pub async fn broadcast(&self, proto_id: &ProtocolID, msg: Vec<u8>) {
147 for (pid, peer) in self.peers.read().await.iter() {
148 if let Err(err) = peer.send(proto_id.to_string(), msg.clone()).await {
149 error!("failed to send msg to {pid}: {err}");
150 continue;
151 }
152 }
153 }
154
155 pub async fn broadcast_to(
157 &self,
158 proto_id: &ProtocolID,
159 msg: Vec<u8>,
160 targets: &HashSet<PeerID>,
161 ) {
162 for (pid, peer) in self.peers.read().await.iter() {
163 if !targets.contains(pid) {
164 continue;
165 }
166 if let Err(err) = peer.send(proto_id.to_string(), msg.clone()).await {
167 error!("failed to send msg to {pid}: {err}");
168 }
169 }
170 }
171
172 pub async fn send_to(
175 &self,
176 peer_id: &PeerID,
177 proto_id: &ProtocolID,
178 msg: Vec<u8>,
179 ) -> Result<()> {
180 let peers = self.peers.read().await;
181 let peer = peers
182 .get(peer_id)
183 .ok_or_else(|| Error::PeerNotFound(peer_id.to_string()))?;
184 peer.send(proto_id.to_string(), msg).await
185 }
186
187 pub async fn peer_protocol_set(&self, pid: &PeerID) -> Option<HashSet<ProtocolID>> {
189 self.peers
190 .read()
191 .await
192 .get(pid)
193 .map(|p| p.negotiated_protocols().clone())
194 }
195
196 pub async fn contains_peer(&self, pid: &PeerID) -> bool {
198 self.peers.read().await.contains_key(pid)
199 }
200
201 pub async fn peers_len(&self) -> usize {
203 self.peers.read().await.len()
204 }
205
206 pub async fn inbound_peers(&self) -> HashMap<PeerID, Endpoint> {
208 let mut peers = HashMap::new();
209 for (id, peer) in self.peers.read().await.iter() {
210 if peer.is_inbound() {
211 peers.insert(id.clone(), peer.remote_endpoint().clone());
212 }
213 }
214 peers
215 }
216
217 pub async fn outbound_peers(&self) -> HashMap<PeerID, Endpoint> {
219 let mut peers = HashMap::new();
220 for (id, peer) in self.peers.read().await.iter() {
221 if !peer.is_inbound() {
222 peers.insert(id.clone(), peer.remote_endpoint().clone());
223 }
224 }
225 peers
226 }
227
228 async fn run(self: Arc<Self>) {
229 loop {
230 let mut queued = self.conn_queue.next().await;
231
232 let meta = self.protocol_meta.read().await.clone();
236
237 let params = HandshakeParams {
238 own_id: &self.id,
239 is_inbound: matches!(queued.direction, crate::peer::ConnDirection::Inbound),
240 config_version: &self.config.version,
241 protocols: &meta,
242 timeout_secs: self.config.handshake_timeout,
243 verified_peer_id: queued.verified_peer_id.as_ref(),
244 };
245 let handshake = handshake(&mut queued.reader, &mut queued.writer, ¶ms).await;
246
247 let (pid, negotiated) = match handshake {
248 Ok(v) => v,
249 Err(err) => {
250 let pid = queued.verified_peer_id.clone();
251 self.monitor
252 .notify(PoolEvent::HandshakeFailed(pid.clone()))
253 .await;
254 let _ = self
255 .peer_emitter
256 .emit(&PeerEvent::HandshakeFailed(pid))
257 .await;
258 let _ = queued.disconnect_signal.send(Err(err)).await;
259 continue;
260 }
261 };
262
263 if let Err(err) = self.new_peer(queued, pid, negotiated).await {
264 error!("new_peer failed: {err}");
265 }
266 }
267 }
268
269 async fn new_peer(
271 self: &Arc<Self>,
272 queued: QueuedConn,
273 pid: PeerID,
274 negotiated: Vec<ProtocolID>,
275 ) -> Result<()> {
276 if self.contains_peer(&pid).await {
277 self.monitor
278 .notify(PoolEvent::PeerAlreadyConnected(pid.clone()))
279 .await;
280 let _ = queued
281 .disconnect_signal
282 .send(Err(Error::PeerAlreadyConnected))
283 .await;
284 return Err(Error::PeerAlreadyConnected);
285 }
286
287 let protocol_ids: Vec<ProtocolID> = self.protocols.read().await.keys().cloned().collect();
288 let negotiated: HashSet<ProtocolID> = negotiated.into_iter().collect();
289
290 let peer = Peer::new(self.clone(), queued, pid.clone(), negotiated, protocol_ids).await?;
291
292 self.peers.write().await.insert(pid.clone(), peer.clone());
293
294 let on_disconnect = {
295 let this = self.clone();
296 let pid = pid.clone();
297 |result| async move {
298 if let TaskResult::Completed(_) = result {
299 if let Err(err) = this.remove_peer(&pid).await {
300 error!("Failed to remove peer {pid}: {err}");
301 }
302 }
303 }
304 };
305
306 self.task_group.spawn(peer.clone().run(), on_disconnect);
307
308 info!("Add new peer {pid}");
309 self.monitor.notify(PoolEvent::NewPeer(pid.clone())).await;
310 let _ = self.peer_emitter.emit(&PeerEvent::Added(pid)).await;
311
312 Ok(())
313 }
314
315 async fn remove_peer(&self, pid: &PeerID) -> Result<()> {
317 let result = self.peers.write().await.remove(pid);
318
319 let peer = match result {
320 Some(p) => p,
321 None => return Ok(()),
322 };
323
324 let _ = peer.shutdown().await;
325
326 self.monitor
327 .notify(PoolEvent::RemovePeer(pid.clone()))
328 .await;
329 let _ = self
330 .peer_emitter
331 .emit(&PeerEvent::Removed(pid.clone()))
332 .await;
333
334 warn!("Peer {pid} removed",);
335 Ok(())
336 }
337}