Skip to main content

karyon_p2p/peer/
mod.rs

1mod connection;
2mod peer_id;
3
4use std::{
5    collections::HashSet,
6    fmt,
7    sync::{Arc, Weak},
8};
9
10use async_channel::{Receiver, Sender};
11use log::{error, trace};
12
13use karyon_core::{
14    async_runtime::Executor,
15    async_util::{TaskGroup, TaskResult},
16};
17
18use crate::{
19    conn_queue::QueuedConn,
20    endpoint::Endpoint,
21    peer_pool::PeerPool,
22    protocol::{PeerConn, ProtocolEvent, ProtocolID},
23    Config, Result,
24};
25
26pub use peer_id::PeerID;
27
28use connection::PeerConnection;
29
/// Direction of a peer connection: `Inbound` or `Outbound`
/// (presumably accepted by vs. dialed from this node).
///
/// Derives `PartialEq`/`Eq` so callers can compare directions directly,
/// e.g. `peer.direction() == &ConnDirection::Outbound`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ConnDirection {
    Inbound,
    Outbound,
}
35
36impl fmt::Display for ConnDirection {
37    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
38        match self {
39            ConnDirection::Inbound => write!(f, "Inbound"),
40            ConnDirection::Outbound => write!(f, "Outbound"),
41        }
42    }
43}
44
/// A connected peer. Holds a `PeerConnection` that hides the wire
/// shape (single framed pipe vs. per-protocol streams).
///
/// Built by `Peer::new` and returned as `Arc<Peer>`. The peer keeps only a
/// `Weak` reference back to its `PeerPool`.
///
/// NOTE(review): struct fields are dropped in declaration order; keep the
/// order below stable unless drop-time interactions (e.g. between
/// `connection` and `task_group`) have been checked.
pub struct Peer {
    /// ID copied from the owning pool (`peer_pool.id`) at construction —
    /// presumably the local node's ID.
    own_id: PeerID,
    /// ID of the connected peer, as passed to `Peer::new`.
    id: PeerID,
    /// Back-reference to the owning pool. `Weak`, so the peer does not keep
    /// the pool alive; upgraded on demand.
    peer_pool: Weak<PeerPool>,

    /// Copied from the `QueuedConn` this peer was built from.
    direction: ConnDirection,
    /// Copied from the `QueuedConn` this peer was built from.
    remote_endpoint: Endpoint,

    /// Transport used by `send`/`recv`; shut down by `Peer::shutdown`.
    connection: Arc<dyn PeerConnection>,
    /// Clone of the `QueuedConn`'s disconnect signal; `Peer::shutdown`
    /// sends `Ok(())` on it.
    disconnect_signal: Sender<Result<()>>,

    /// Protocol IDs passed in as negotiated; only protocols in this set are
    /// started by `run_connect_protocols`.
    negotiated_protocols: HashSet<ProtocolID>,
    /// Stop channel, created with capacity 1. Completed protocol tasks and
    /// `shutdown` push into it, and a sender clone is handed to
    /// `connection::from_queued` (which may push too); `run` returns the
    /// first value it receives.
    stop_chan: (Sender<Result<()>>, Receiver<Result<()>>),
    /// Config `Arc` cloned from the owning pool.
    config: Arc<Config>,
    /// Executor cloned from the owning pool; also backs `task_group`.
    executor: Executor,
    /// Task group for this peer: protocol tasks are spawned into it, it is
    /// handed to `connection::from_queued`, and `shutdown` cancels it.
    task_group: TaskGroup,
}
64
65impl Peer {
66    pub async fn send(&self, proto_id: ProtocolID, msg: Vec<u8>) -> Result<()> {
67        self.connection.send(&proto_id, msg).await
68    }
69
70    pub async fn recv(&self, proto_id: &ProtocolID) -> Result<ProtocolEvent> {
71        self.connection.recv(proto_id).await
72    }
73
74    pub async fn broadcast(&self, proto_id: &ProtocolID, msg: Vec<u8>) {
75        self.peer_pool().broadcast(proto_id, msg).await;
76    }
77
78    pub fn id(&self) -> &PeerID {
79        &self.id
80    }
81
82    pub fn own_id(&self) -> &PeerID {
83        &self.own_id
84    }
85
86    pub fn config(&self) -> Arc<Config> {
87        self.config.clone()
88    }
89
90    pub fn executor(&self) -> Executor {
91        self.executor.clone()
92    }
93
94    pub fn remote_endpoint(&self) -> &Endpoint {
95        &self.remote_endpoint
96    }
97
98    pub fn is_inbound(&self) -> bool {
99        matches!(self.direction, ConnDirection::Inbound)
100    }
101
102    pub fn direction(&self) -> &ConnDirection {
103        &self.direction
104    }
105
106    pub fn negotiated_protocols(&self) -> &HashSet<ProtocolID> {
107        &self.negotiated_protocols
108    }
109
110    pub(crate) async fn run(self: Arc<Self>) -> Result<()> {
111        self.run_connect_protocols().await;
112        let stop_signal = self.stop_chan.1.recv().await?;
113        stop_signal
114    }
115
116    pub(crate) async fn shutdown(self: &Arc<Self>) -> Result<()> {
117        trace!("peer {} shutting down", self.id);
118
119        let _ = self.connection.shutdown().await;
120        let _ = self.stop_chan.0.try_send(Ok(()));
121
122        let _ = self.disconnect_signal.send(Ok(())).await;
123        self.task_group.cancel().await;
124        Ok(())
125    }
126
127    async fn run_connect_protocols(self: &Arc<Self>) {
128        for (proto_id, constructor) in self.peer_pool().protocols.read().await.iter() {
129            if !self.negotiated_protocols.contains(proto_id) {
130                trace!("peer {} skip protocol {proto_id} (not negotiated)", self.id);
131                continue;
132            }
133            trace!("peer {} run protocol {proto_id}", self.id);
134
135            let peer_conn = PeerConn::new(self.clone(), proto_id.clone());
136            let protocol = match constructor(peer_conn) {
137                Ok(p) => p,
138                Err(err) => {
139                    error!("Failed to build protocol {proto_id}: {err}");
140                    continue;
141                }
142            };
143
144            let on_failure = {
145                let this = self.clone();
146                let proto_id = proto_id.clone();
147                |result: TaskResult<Result<()>>| async move {
148                    if let TaskResult::Completed(res) = result {
149                        if res.is_err() {
150                            error!("protocol {proto_id} stopped");
151                        }
152                        let _ = this.stop_chan.0.try_send(res);
153                    }
154                }
155            };
156
157            self.task_group.spawn(protocol.start(), on_failure);
158        }
159    }
160
161    fn peer_pool(&self) -> Arc<PeerPool> {
162        self.peer_pool.upgrade().unwrap()
163    }
164}
165
166impl Peer {
167    pub(crate) async fn new(
168        peer_pool: Arc<PeerPool>,
169        queued: QueuedConn,
170        id: PeerID,
171        negotiated_protocols: HashSet<ProtocolID>,
172        protocol_ids: impl IntoIterator<Item = ProtocolID> + Clone,
173    ) -> Result<Arc<Self>> {
174        let own_id = peer_pool.id.clone();
175        let config = peer_pool.config.clone();
176        let executor = peer_pool.executor.clone();
177        let task_group = TaskGroup::with_executor(executor.clone());
178        let stop_chan = async_channel::bounded::<Result<()>>(1);
179
180        let remote_endpoint = queued.remote_endpoint.clone();
181        let direction = queued.direction.clone();
182        let disconnect_signal = queued.disconnect_signal.clone();
183
184        let connection = connection::from_queued(
185            queued,
186            &negotiated_protocols,
187            protocol_ids,
188            &task_group,
189            stop_chan.0.clone(),
190        )
191        .await?;
192
193        let peer_pool_weak = Arc::downgrade(&peer_pool);
194        Ok(Arc::new(Peer {
195            own_id,
196            id,
197            peer_pool: peer_pool_weak,
198            direction,
199            remote_endpoint,
200            connection,
201            disconnect_signal,
202            negotiated_protocols,
203            stop_chan,
204            config,
205            executor,
206            task_group,
207        }))
208    }
209}