1mod peer_id;
2
3use std::sync::{Arc, Weak};
4
5use async_channel::{Receiver, Sender};
6use bincode::Encode;
7use log::{error, trace};
8use parking_lot::RwLock;
9
10use karyon_core::{
11 async_runtime::Executor,
12 async_util::{select, Either, TaskGroup, TaskResult},
13 util::decode,
14};
15
16use crate::{
17 connection::{ConnDirection, Connection},
18 endpoint::Endpoint,
19 message::{NetMsgCmd, ProtocolMsg},
20 peer_pool::PeerPool,
21 protocol::{InitProtocol, Protocol, ProtocolEvent, ProtocolID},
22 protocols::HandshakeProtocol,
23 Config, Error, Result,
24};
25
26pub use peer_id::PeerID;
27
/// A connection to a remote peer, together with the protocol tasks spawned for it.
pub struct Peer {
    /// The ID this node uses for itself.
    own_id: PeerID,

    /// The remote peer's ID; `None` until the handshake in `Peer::init` succeeds.
    id: RwLock<Option<PeerID>>,

    /// Non-owning handle to the peer pool this peer was created by.
    peer_pool: Weak<PeerPool>,

    /// The underlying connection to the remote peer.
    pub(crate) conn: Arc<Connection>,

    /// Channel used to stop the read loop; each message carries the result the
    /// read loop should return.
    stop_chan: (Sender<Result<()>>, Receiver<Result<()>>),

    /// Configuration, shared via `Arc`.
    config: Arc<Config>,

    /// Task group on which the protocol tasks are spawned.
    task_group: TaskGroup,
}
50
51impl Peer {
52 pub(crate) fn new(
54 own_id: PeerID,
55 peer_pool: Weak<PeerPool>,
56 conn: Arc<Connection>,
57 config: Arc<Config>,
58 ex: Executor,
59 ) -> Arc<Peer> {
60 Arc::new(Peer {
61 own_id,
62 id: RwLock::new(None),
63 peer_pool,
64 conn,
65 config,
66 task_group: TaskGroup::with_executor(ex),
67 stop_chan: async_channel::bounded(1),
68 })
69 }
70
71 pub async fn send<T: Encode>(&self, proto_id: ProtocolID, msg: T) -> Result<()> {
73 self.conn.send(proto_id, msg).await
74 }
75
76 pub async fn recv<P: Protocol>(&self) -> Result<ProtocolEvent> {
78 self.conn.recv::<P>().await
79 }
80
81 pub async fn broadcast<T: Encode>(&self, proto_id: &ProtocolID, msg: &T) {
83 self.peer_pool().broadcast(proto_id, msg).await;
84 }
85
86 pub fn id(&self) -> Option<PeerID> {
88 self.id.read().clone()
89 }
90
91 pub fn own_id(&self) -> &PeerID {
93 &self.own_id
94 }
95
96 pub fn config(&self) -> Arc<Config> {
98 self.config.clone()
99 }
100
101 pub fn remote_endpoint(&self) -> &Endpoint {
103 &self.conn.remote_endpoint
104 }
105
106 pub fn is_inbound(&self) -> bool {
108 match self.conn.direction {
109 ConnDirection::Inbound => true,
110 ConnDirection::Outbound => false,
111 }
112 }
113
114 pub fn direction(&self) -> &ConnDirection {
117 &self.conn.direction
118 }
119
120 pub(crate) async fn init(self: &Arc<Self>) -> Result<()> {
121 let handshake_protocol = HandshakeProtocol::new(
122 self.clone(),
123 self.peer_pool().protocol_versions.read().await.clone(),
124 );
125
126 let pid = handshake_protocol.init().await?;
127 *self.id.write() = Some(pid);
128
129 Ok(())
130 }
131
132 pub(crate) async fn run(self: Arc<Self>) -> Result<()> {
134 self.run_connect_protocols().await;
135 self.read_loop().await
136 }
137
138 pub(crate) async fn shutdown(self: &Arc<Self>) -> Result<()> {
140 trace!("peer {:?} shutting down", self.id());
141
142 for proto_id in self.peer_pool().protocols.read().await.keys() {
144 let _ = self.conn.emit_msg(proto_id, &ProtocolEvent::Shutdown).await;
145 }
146
147 let _ = self.stop_chan.0.try_send(Ok(()));
152
153 self.conn.disconnect(Ok(())).await?;
154
155 self.task_group.cancel().await;
157 Ok(())
158 }
159
160 async fn run_connect_protocols(self: &Arc<Self>) {
162 for (proto_id, constructor) in self.peer_pool().protocols.read().await.iter() {
163 trace!("peer {:?} run protocol {proto_id}", self.id());
164
165 let protocol = constructor(self.clone());
166
167 let on_failure = {
168 let this = self.clone();
169 let proto_id = proto_id.clone();
170 |result: TaskResult<Result<()>>| async move {
171 if let TaskResult::Completed(res) = result {
172 if res.is_err() {
173 error!("protocol {} stopped", proto_id);
174 }
175 let _ = this.stop_chan.0.try_send(res);
177 }
178 }
179 };
180
181 self.task_group.spawn(protocol.start(), on_failure);
182 }
183 }
184
185 async fn read_loop(&self) -> Result<()> {
187 loop {
188 let fut = select(self.stop_chan.1.recv(), self.conn.recv_inner()).await;
189 let result = match fut {
190 Either::Left(stop_signal) => {
191 trace!("Peer {:?} received a stop signal", self.id());
192 return stop_signal?;
193 }
194 Either::Right(result) => result,
195 };
196
197 let msg = result?;
198
199 match msg.header.command {
200 NetMsgCmd::Protocol => {
201 let msg: ProtocolMsg = decode(&msg.payload)?.0;
202 self.conn
203 .emit_msg(&msg.protocol_id, &ProtocolEvent::Message(msg.payload))
204 .await?;
205 }
206 NetMsgCmd::Shutdown => {
207 return Err(Error::PeerShutdown);
208 }
209 command => return Err(Error::InvalidMsg(format!("Unexpected msg {:?}", command))),
210 }
211 }
212 }
213
214 fn peer_pool(&self) -> Arc<PeerPool> {
216 self.peer_pool.upgrade().unwrap()
217 }
218}