karyon_p2p/
connection.rs

1use std::{collections::HashMap, fmt, sync::Arc};
2
3use async_channel::Sender;
4use bincode::Encode;
5
6use karyon_core::util::encode;
7use karyon_eventemitter::{EventEmitter, EventListener};
8
9use karyon_net::Endpoint;
10
11use crate::{
12    message::{NetMsg, NetMsgCmd, ProtocolMsg, ShutdownMsg},
13    protocol::{Protocol, ProtocolEvent, ProtocolID},
14    ConnRef, Error, Result,
15};
16
/// Defines the direction of a network connection.
///
/// This is a field-less enum, so besides `Clone` and `Debug` it is also
/// `Copy`, comparable with `==`/`!=`, and usable as a key in hashed
/// collections.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ConnDirection {
    /// The connection is an inbound one.
    Inbound,
    /// The connection is an outbound one.
    Outbound,
}
23
24impl fmt::Display for ConnDirection {
25    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
26        match self {
27            ConnDirection::Inbound => write!(f, "Inbound"),
28            ConnDirection::Outbound => write!(f, "Outbound"),
29        }
30    }
31}
32
33pub struct Connection {
34    pub(crate) direction: ConnDirection,
35    conn: ConnRef,
36    disconnect_signal: Sender<Result<()>>,
37    /// `EventEmitter` responsible for sending events to the registered protocols.
38    protocol_events: Arc<EventEmitter<ProtocolID>>,
39    pub(crate) remote_endpoint: Endpoint,
40    listeners: HashMap<ProtocolID, EventListener<ProtocolID, ProtocolEvent>>,
41}
42
43impl Connection {
44    pub fn new(
45        conn: ConnRef,
46        signal: Sender<Result<()>>,
47        direction: ConnDirection,
48        remote_endpoint: Endpoint,
49    ) -> Self {
50        Self {
51            conn,
52            direction,
53            protocol_events: EventEmitter::new(),
54            disconnect_signal: signal,
55            remote_endpoint,
56            listeners: HashMap::new(),
57        }
58    }
59
60    pub async fn send<T: Encode>(&self, protocol_id: ProtocolID, msg: T) -> Result<()> {
61        let payload = encode(&msg)?;
62
63        let proto_msg = ProtocolMsg {
64            protocol_id,
65            payload: payload.to_vec(),
66        };
67
68        let msg = NetMsg::new(NetMsgCmd::Protocol, &proto_msg)?;
69        self.conn.send(msg).await
70    }
71
72    pub async fn recv<P: Protocol>(&self) -> Result<ProtocolEvent> {
73        match self.listeners.get(&P::id()) {
74            Some(l) => l.recv().await.map_err(Error::from),
75            None => Err(Error::UnsupportedProtocol(P::id())),
76        }
77    }
78
79    /// Registers a listener for the given Protocol `P`.
80    pub async fn register_protocol(&mut self, protocol_id: String) {
81        let listener = self.protocol_events.register(&protocol_id);
82        self.listeners.insert(protocol_id, listener);
83    }
84
85    pub async fn emit_msg(&self, id: &ProtocolID, event: &ProtocolEvent) -> Result<()> {
86        self.protocol_events.emit_by_topic(id, event).await?;
87        Ok(())
88    }
89
90    pub async fn recv_inner(&self) -> Result<NetMsg> {
91        self.conn.recv().await
92    }
93
94    pub async fn send_inner(&self, msg: NetMsg) -> Result<()> {
95        self.conn.send(msg).await
96    }
97
98    pub async fn disconnect(&self, res: Result<()>) -> Result<()> {
99        self.protocol_events.clear();
100        self.disconnect_signal.send(res).await?;
101
102        let m = NetMsg::new(NetMsgCmd::Shutdown, ShutdownMsg(0)).expect("Create shutdown message");
103        self.conn.send(m).await?;
104
105        Ok(())
106    }
107}