karyon_p2p/peer/mod.rs

1mod peer_id;
2
3use std::sync::{Arc, Weak};
4
5use async_channel::{Receiver, Sender};
6use bincode::Encode;
7use log::{error, trace};
8use parking_lot::RwLock;
9
10use karyon_core::{
11    async_runtime::Executor,
12    async_util::{select, Either, TaskGroup, TaskResult},
13    util::decode,
14};
15
16use crate::{
17    connection::{ConnDirection, Connection},
18    endpoint::Endpoint,
19    message::{NetMsgCmd, ProtocolMsg},
20    peer_pool::PeerPool,
21    protocol::{InitProtocol, Protocol, ProtocolEvent, ProtocolID},
22    protocols::HandshakeProtocol,
23    Config, Error, Result,
24};
25
26pub use peer_id::PeerID;
27
28pub struct Peer {
29    /// Own ID
30    own_id: PeerID,
31
32    /// Peer's ID
33    id: RwLock<Option<PeerID>>,
34
35    /// A weak pointer to [`PeerPool`]
36    peer_pool: Weak<PeerPool>,
37
38    /// Holds the peer connection
39    pub(crate) conn: Arc<Connection>,
40
41    /// This channel is used to send a stop signal to the read loop.
42    stop_chan: (Sender<Result<()>>, Receiver<Result<()>>),
43
44    /// The Configuration for the P2P network.
45    config: Arc<Config>,
46
47    /// Managing spawned tasks.
48    task_group: TaskGroup,
49}
50
51impl Peer {
52    /// Creates a new peer
53    pub(crate) fn new(
54        own_id: PeerID,
55        peer_pool: Weak<PeerPool>,
56        conn: Arc<Connection>,
57        config: Arc<Config>,
58        ex: Executor,
59    ) -> Arc<Peer> {
60        Arc::new(Peer {
61            own_id,
62            id: RwLock::new(None),
63            peer_pool,
64            conn,
65            config,
66            task_group: TaskGroup::with_executor(ex),
67            stop_chan: async_channel::bounded(1),
68        })
69    }
70
71    /// Send a msg to this peer connection using the specified protocol.
72    pub async fn send<T: Encode>(&self, proto_id: ProtocolID, msg: T) -> Result<()> {
73        self.conn.send(proto_id, msg).await
74    }
75
76    /// Receives a new msg from this peer connection.
77    pub async fn recv<P: Protocol>(&self) -> Result<ProtocolEvent> {
78        self.conn.recv::<P>().await
79    }
80
81    /// Broadcast a message to all connected peers using the specified protocol.
82    pub async fn broadcast<T: Encode>(&self, proto_id: &ProtocolID, msg: &T) {
83        self.peer_pool().broadcast(proto_id, msg).await;
84    }
85
86    /// Returns the peer's ID
87    pub fn id(&self) -> Option<PeerID> {
88        self.id.read().clone()
89    }
90
91    /// Returns own ID
92    pub fn own_id(&self) -> &PeerID {
93        &self.own_id
94    }
95
96    /// Returns the [`Config`]
97    pub fn config(&self) -> Arc<Config> {
98        self.config.clone()
99    }
100
101    /// Returns the remote endpoint for the peer
102    pub fn remote_endpoint(&self) -> &Endpoint {
103        &self.conn.remote_endpoint
104    }
105
106    /// Check if the connection is Inbound
107    pub fn is_inbound(&self) -> bool {
108        match self.conn.direction {
109            ConnDirection::Inbound => true,
110            ConnDirection::Outbound => false,
111        }
112    }
113
114    /// Returns the direction of the connection, which can be either `Inbound`
115    /// or `Outbound`.
116    pub fn direction(&self) -> &ConnDirection {
117        &self.conn.direction
118    }
119
120    pub(crate) async fn init(self: &Arc<Self>) -> Result<()> {
121        let handshake_protocol = HandshakeProtocol::new(
122            self.clone(),
123            self.peer_pool().protocol_versions.read().await.clone(),
124        );
125
126        let pid = handshake_protocol.init().await?;
127        *self.id.write() = Some(pid);
128
129        Ok(())
130    }
131
132    /// Run the peer
133    pub(crate) async fn run(self: Arc<Self>) -> Result<()> {
134        self.run_connect_protocols().await;
135        self.read_loop().await
136    }
137
138    /// Shuts down the peer
139    pub(crate) async fn shutdown(self: &Arc<Self>) -> Result<()> {
140        trace!("peer {:?} shutting down", self.id());
141
142        // Send shutdown event to the attached protocols
143        for proto_id in self.peer_pool().protocols.read().await.keys() {
144            let _ = self.conn.emit_msg(proto_id, &ProtocolEvent::Shutdown).await;
145        }
146
147        // Send a stop signal to the read loop
148        //
149        // No need to handle the error here; a dropped channel and
150        // sendig a stop signal have the same effect.
151        let _ = self.stop_chan.0.try_send(Ok(()));
152
153        self.conn.disconnect(Ok(())).await?;
154
155        // Force shutting down
156        self.task_group.cancel().await;
157        Ok(())
158    }
159
160    /// Run running the Connect Protocols for this peer connection.
161    async fn run_connect_protocols(self: &Arc<Self>) {
162        for (proto_id, constructor) in self.peer_pool().protocols.read().await.iter() {
163            trace!("peer {:?} run protocol {proto_id}", self.id());
164
165            let protocol = constructor(self.clone());
166
167            let on_failure = {
168                let this = self.clone();
169                let proto_id = proto_id.clone();
170                |result: TaskResult<Result<()>>| async move {
171                    if let TaskResult::Completed(res) = result {
172                        if res.is_err() {
173                            error!("protocol {} stopped", proto_id);
174                        }
175                        // Send a stop signal to read loop
176                        let _ = this.stop_chan.0.try_send(res);
177                    }
178                }
179            };
180
181            self.task_group.spawn(protocol.start(), on_failure);
182        }
183    }
184
185    /// Run a read loop to handle incoming messages from the peer connection.
186    async fn read_loop(&self) -> Result<()> {
187        loop {
188            let fut = select(self.stop_chan.1.recv(), self.conn.recv_inner()).await;
189            let result = match fut {
190                Either::Left(stop_signal) => {
191                    trace!("Peer {:?} received a stop signal", self.id());
192                    return stop_signal?;
193                }
194                Either::Right(result) => result,
195            };
196
197            let msg = result?;
198
199            match msg.header.command {
200                NetMsgCmd::Protocol => {
201                    let msg: ProtocolMsg = decode(&msg.payload)?.0;
202                    self.conn
203                        .emit_msg(&msg.protocol_id, &ProtocolEvent::Message(msg.payload))
204                        .await?;
205                }
206                NetMsgCmd::Shutdown => {
207                    return Err(Error::PeerShutdown);
208                }
209                command => return Err(Error::InvalidMsg(format!("Unexpected msg {:?}", command))),
210            }
211        }
212    }
213
214    /// Returns `PeerPool` pointer
215    fn peer_pool(&self) -> Arc<PeerPool> {
216        self.peer_pool.upgrade().unwrap()
217    }
218}