1use std::{collections::HashMap, sync::Arc};
2
3use bincode::Encode;
4use log::{error, info, warn};
5
6use karyon_core::{
7 async_runtime::{lock::RwLock, Executor},
8 async_util::{TaskGroup, TaskResult},
9};
10
11use karyon_net::Endpoint;
12
13use crate::{
14 config::Config,
15 conn_queue::ConnQueue,
16 connection::Connection,
17 monitor::{Monitor, PPEvent},
18 peer::Peer,
19 protocol::{Protocol, ProtocolConstructor, ProtocolID},
20 protocols::PingProtocol,
21 version::Version,
22 Error, PeerID, Result,
23};
24
25pub struct PeerPool {
26 pub id: PeerID,
28
29 conn_queue: Arc<ConnQueue>,
31
32 peers: RwLock<HashMap<PeerID, Arc<Peer>>>,
34
35 pub(crate) protocols: RwLock<HashMap<ProtocolID, Box<ProtocolConstructor>>>,
37
38 pub(crate) protocol_versions: RwLock<HashMap<ProtocolID, Version>>,
40
41 task_group: TaskGroup,
43
44 executor: Executor,
46
47 config: Arc<Config>,
49
50 monitor: Arc<Monitor>,
52}
53
54impl PeerPool {
55 pub fn new(
57 id: &PeerID,
58 conn_queue: Arc<ConnQueue>,
59 config: Arc<Config>,
60 monitor: Arc<Monitor>,
61 executor: Executor,
62 ) -> Arc<Self> {
63 Arc::new(Self {
64 id: id.clone(),
65 conn_queue,
66 peers: RwLock::new(HashMap::new()),
67 protocols: RwLock::new(HashMap::new()),
68 protocol_versions: RwLock::new(HashMap::new()),
69 task_group: TaskGroup::with_executor(executor.clone()),
70 executor,
71 monitor,
72 config,
73 })
74 }
75
76 pub async fn start(self: &Arc<Self>) -> Result<()> {
78 self.setup_core_protocols().await?;
79 self.task_group.spawn(self.clone().run(), |_| async {});
80 Ok(())
81 }
82
83 pub async fn shutdown(&self) {
85 for (_, peer) in self.peers.read().await.iter() {
86 let _ = peer.shutdown().await;
87 }
88
89 self.task_group.cancel().await;
90 }
91
92 pub async fn attach_protocol<P: Protocol>(&self, c: Box<ProtocolConstructor>) -> Result<()> {
94 self.protocols.write().await.insert(P::id(), c);
95 self.protocol_versions
96 .write()
97 .await
98 .insert(P::id(), P::version()?);
99 Ok(())
100 }
101
102 pub async fn broadcast<T: Encode>(&self, proto_id: &ProtocolID, msg: &T) {
104 for (pid, peer) in self.peers.read().await.iter() {
105 if let Err(err) = peer.conn.send(proto_id.to_string(), msg).await {
106 error!("failed to send msg to {pid}: {err}");
107 continue;
108 }
109 }
110 }
111
112 pub async fn contains_peer(&self, pid: &PeerID) -> bool {
114 self.peers.read().await.contains_key(pid)
115 }
116
117 pub async fn peers_len(&self) -> usize {
119 self.peers.read().await.len()
120 }
121
122 pub async fn inbound_peers(&self) -> HashMap<PeerID, Endpoint> {
124 let mut peers = HashMap::new();
125 for (id, peer) in self.peers.read().await.iter() {
126 if peer.is_inbound() {
127 peers.insert(id.clone(), peer.remote_endpoint().clone());
128 }
129 }
130 peers
131 }
132
133 pub async fn outbound_peers(&self) -> HashMap<PeerID, Endpoint> {
135 let mut peers = HashMap::new();
136 for (id, peer) in self.peers.read().await.iter() {
137 if !peer.is_inbound() {
138 peers.insert(id.clone(), peer.remote_endpoint().clone());
139 }
140 }
141 peers
142 }
143
144 async fn run(self: Arc<Self>) {
145 loop {
146 let mut conn = self.conn_queue.next().await;
147
148 for protocol_id in self.protocols.read().await.keys() {
149 conn.register_protocol(protocol_id.to_string()).await;
150 }
151
152 let conn = Arc::new(conn);
153
154 let result = self.new_peer(conn.clone()).await;
155
156 if result.is_err() {
158 let _ = conn.disconnect(result).await;
159 }
160 }
161 }
162
163 async fn new_peer(self: &Arc<Self>, conn: Arc<Connection>) -> Result<()> {
165 let peer = Peer::new(
167 self.id.clone(),
168 Arc::downgrade(self),
169 conn.clone(),
170 self.config.clone(),
171 self.executor.clone(),
172 );
173 peer.init().await?;
174 let pid = peer.id().expect("Get peer id after peer initialization");
175
176 if self.contains_peer(&pid).await {
178 return Err(Error::PeerAlreadyConnected);
179 }
180
181 self.peers.write().await.insert(pid.clone(), peer.clone());
183
184 let on_disconnect = {
185 let this = self.clone();
186 let pid = pid.clone();
187 |result| async move {
188 if let TaskResult::Completed(_) = result {
189 if let Err(err) = this.remove_peer(&pid).await {
190 error!("Failed to remove peer {pid}: {err}");
191 }
192 }
193 }
194 };
195
196 self.task_group.spawn(peer.run(), on_disconnect);
197
198 info!("Add new peer {pid}");
199 self.monitor.notify(PPEvent::NewPeer(pid)).await;
200
201 Ok(())
202 }
203
204 async fn remove_peer(&self, pid: &PeerID) -> Result<()> {
206 let result = self.peers.write().await.remove(pid);
207
208 let peer = match result {
209 Some(p) => p,
210 None => return Ok(()),
211 };
212
213 let _ = peer.shutdown().await;
214
215 self.monitor.notify(PPEvent::RemovePeer(pid.clone())).await;
216
217 warn!("Peer {pid} removed",);
218 Ok(())
219 }
220
221 async fn setup_core_protocols(&self) -> Result<()> {
223 let executor = self.executor.clone();
224 let ping_interval = self.config.ping_interval;
225 let ping_timeout = self.config.ping_timeout;
226 let c = move |peer| PingProtocol::new(peer, ping_interval, ping_timeout, executor.clone());
227 self.attach_protocol::<PingProtocol>(Box::new(c)).await
228 }
229}