1mod connection;
2mod peer_id;
3
4use std::{
5 collections::HashSet,
6 fmt,
7 sync::{Arc, Weak},
8};
9
10use async_channel::{Receiver, Sender};
11use log::{error, trace};
12
13use karyon_core::{
14 async_runtime::Executor,
15 async_util::{TaskGroup, TaskResult},
16};
17
18use crate::{
19 conn_queue::QueuedConn,
20 endpoint::Endpoint,
21 peer_pool::PeerPool,
22 protocol::{PeerConn, ProtocolEvent, ProtocolID},
23 Config, Result,
24};
25
26pub use peer_id::PeerID;
27
28use connection::PeerConnection;
29
/// Direction of a peer connection.
///
/// `PartialEq`/`Eq` are derived so that directions can be compared directly
/// (`dir == ConnDirection::Inbound`) instead of only through pattern matching.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ConnDirection {
    /// Inbound connection.
    Inbound,
    /// Outbound connection.
    Outbound,
}
35
36impl fmt::Display for ConnDirection {
37 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
38 match self {
39 ConnDirection::Inbound => write!(f, "Inbound"),
40 ConnDirection::Outbound => write!(f, "Outbound"),
41 }
42 }
43}
44
/// A single peer connection together with the state needed to run the
/// negotiated protocols on top of it.
///
/// Instances are created through `Peer::new` and handed out as `Arc<Peer>`.
pub struct Peer {
    /// ID of the local node, cloned from the owning pool (`peer_pool.id`).
    own_id: PeerID,
    /// ID of this peer, as supplied to `Peer::new`.
    id: PeerID,
    /// Weak handle to the owning [`PeerPool`]; upgraded on demand by
    /// `peer_pool()`.
    peer_pool: Weak<PeerPool>,

    /// Whether the connection is inbound or outbound; cloned from the queued
    /// connection.
    direction: ConnDirection,
    /// Endpoint of the remote side; cloned from the queued connection.
    remote_endpoint: Endpoint,

    /// Connection object built by `connection::from_queued`; `send`, `recv`
    /// and `shutdown` delegate to it.
    connection: Arc<dyn PeerConnection>,
    /// Sender cloned from the queued connection; `shutdown` sends `Ok(())`
    /// on it.
    disconnect_signal: Sender<Result<()>>,

    /// Protocol IDs negotiated for this peer. Only protocols in this set are
    /// started by `run_connect_protocols`.
    negotiated_protocols: HashSet<ProtocolID>,
    /// Bounded (capacity 1) stop channel. `run` waits on the receiver; the
    /// sender is used by `shutdown` and by finished protocol tasks, and a
    /// clone of it is passed to `connection::from_queued`.
    stop_chan: (Sender<Result<()>>, Receiver<Result<()>>),
    /// Configuration, cloned from the owning pool.
    config: Arc<Config>,
    /// Executor, cloned from the owning pool.
    executor: Executor,
    /// Task group (created on `executor`) in which protocol tasks are spawned;
    /// cancelled by `shutdown`.
    task_group: TaskGroup,
}
64
65impl Peer {
66 pub async fn send(&self, proto_id: ProtocolID, msg: Vec<u8>) -> Result<()> {
67 self.connection.send(&proto_id, msg).await
68 }
69
70 pub async fn recv(&self, proto_id: &ProtocolID) -> Result<ProtocolEvent> {
71 self.connection.recv(proto_id).await
72 }
73
74 pub async fn broadcast(&self, proto_id: &ProtocolID, msg: Vec<u8>) {
75 self.peer_pool().broadcast(proto_id, msg).await;
76 }
77
78 pub fn id(&self) -> &PeerID {
79 &self.id
80 }
81
82 pub fn own_id(&self) -> &PeerID {
83 &self.own_id
84 }
85
86 pub fn config(&self) -> Arc<Config> {
87 self.config.clone()
88 }
89
90 pub fn executor(&self) -> Executor {
91 self.executor.clone()
92 }
93
94 pub fn remote_endpoint(&self) -> &Endpoint {
95 &self.remote_endpoint
96 }
97
98 pub fn is_inbound(&self) -> bool {
99 matches!(self.direction, ConnDirection::Inbound)
100 }
101
102 pub fn direction(&self) -> &ConnDirection {
103 &self.direction
104 }
105
106 pub fn negotiated_protocols(&self) -> &HashSet<ProtocolID> {
107 &self.negotiated_protocols
108 }
109
110 pub(crate) async fn run(self: Arc<Self>) -> Result<()> {
111 self.run_connect_protocols().await;
112 let stop_signal = self.stop_chan.1.recv().await?;
113 stop_signal
114 }
115
116 pub(crate) async fn shutdown(self: &Arc<Self>) -> Result<()> {
117 trace!("peer {} shutting down", self.id);
118
119 let _ = self.connection.shutdown().await;
120 let _ = self.stop_chan.0.try_send(Ok(()));
121
122 let _ = self.disconnect_signal.send(Ok(())).await;
123 self.task_group.cancel().await;
124 Ok(())
125 }
126
127 async fn run_connect_protocols(self: &Arc<Self>) {
128 for (proto_id, constructor) in self.peer_pool().protocols.read().await.iter() {
129 if !self.negotiated_protocols.contains(proto_id) {
130 trace!("peer {} skip protocol {proto_id} (not negotiated)", self.id);
131 continue;
132 }
133 trace!("peer {} run protocol {proto_id}", self.id);
134
135 let peer_conn = PeerConn::new(self.clone(), proto_id.clone());
136 let protocol = match constructor(peer_conn) {
137 Ok(p) => p,
138 Err(err) => {
139 error!("Failed to build protocol {proto_id}: {err}");
140 continue;
141 }
142 };
143
144 let on_failure = {
145 let this = self.clone();
146 let proto_id = proto_id.clone();
147 |result: TaskResult<Result<()>>| async move {
148 if let TaskResult::Completed(res) = result {
149 if res.is_err() {
150 error!("protocol {proto_id} stopped");
151 }
152 let _ = this.stop_chan.0.try_send(res);
153 }
154 }
155 };
156
157 self.task_group.spawn(protocol.start(), on_failure);
158 }
159 }
160
161 fn peer_pool(&self) -> Arc<PeerPool> {
162 self.peer_pool.upgrade().unwrap()
163 }
164}
165
166impl Peer {
167 pub(crate) async fn new(
168 peer_pool: Arc<PeerPool>,
169 queued: QueuedConn,
170 id: PeerID,
171 negotiated_protocols: HashSet<ProtocolID>,
172 protocol_ids: impl IntoIterator<Item = ProtocolID> + Clone,
173 ) -> Result<Arc<Self>> {
174 let own_id = peer_pool.id.clone();
175 let config = peer_pool.config.clone();
176 let executor = peer_pool.executor.clone();
177 let task_group = TaskGroup::with_executor(executor.clone());
178 let stop_chan = async_channel::bounded::<Result<()>>(1);
179
180 let remote_endpoint = queued.remote_endpoint.clone();
181 let direction = queued.direction.clone();
182 let disconnect_signal = queued.disconnect_signal.clone();
183
184 let connection = connection::from_queued(
185 queued,
186 &negotiated_protocols,
187 protocol_ids,
188 &task_group,
189 stop_chan.0.clone(),
190 )
191 .await?;
192
193 let peer_pool_weak = Arc::downgrade(&peer_pool);
194 Ok(Arc::new(Peer {
195 own_id,
196 id,
197 peer_pool: peer_pool_weak,
198 direction,
199 remote_endpoint,
200 connection,
201 disconnect_signal,
202 negotiated_protocols,
203 stop_chan,
204 config,
205 executor,
206 task_group,
207 }))
208 }
209}