karyon_p2p/
connection.rs

1use std::{collections::HashMap, fmt, sync::Arc};
2
3use async_channel::Sender;
4use bincode::Encode;
5
6use karyon_core::{
7    event::{EventEmitter, EventListener},
8    util::encode,
9};
10
11use karyon_net::Endpoint;
12
13use crate::{
14    message::{NetMsg, NetMsgCmd, ProtocolMsg, ShutdownMsg},
15    protocol::{Protocol, ProtocolEvent, ProtocolID},
16    ConnRef, Error, Result,
17};
18
/// Direction of a network connection.
#[derive(Debug, Clone)]
pub enum ConnDirection {
    /// An inbound connection (presumably accepted from a remote peer).
    Inbound,
    /// An outbound connection (presumably dialed by the local node).
    Outbound,
}
25
26impl fmt::Display for ConnDirection {
27    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
28        match self {
29            ConnDirection::Inbound => write!(f, "Inbound"),
30            ConnDirection::Outbound => write!(f, "Outbound"),
31        }
32    }
33}
34
35pub struct Connection {
36    pub(crate) direction: ConnDirection,
37    conn: ConnRef,
38    disconnect_signal: Sender<Result<()>>,
39    /// `EventEmitter` responsible for sending events to the registered protocols.
40    protocol_events: Arc<EventEmitter<ProtocolID>>,
41    pub(crate) remote_endpoint: Endpoint,
42    listeners: HashMap<ProtocolID, EventListener<ProtocolID, ProtocolEvent>>,
43}
44
45impl Connection {
46    pub fn new(
47        conn: ConnRef,
48        signal: Sender<Result<()>>,
49        direction: ConnDirection,
50        remote_endpoint: Endpoint,
51    ) -> Self {
52        Self {
53            conn,
54            direction,
55            protocol_events: EventEmitter::new(),
56            disconnect_signal: signal,
57            remote_endpoint,
58            listeners: HashMap::new(),
59        }
60    }
61
62    pub async fn send<T: Encode>(&self, protocol_id: ProtocolID, msg: T) -> Result<()> {
63        let payload = encode(&msg)?;
64
65        let proto_msg = ProtocolMsg {
66            protocol_id,
67            payload: payload.to_vec(),
68        };
69
70        let msg = NetMsg::new(NetMsgCmd::Protocol, &proto_msg)?;
71        self.conn.send(msg).await
72    }
73
74    pub async fn recv<P: Protocol>(&self) -> Result<ProtocolEvent> {
75        match self.listeners.get(&P::id()) {
76            Some(l) => l.recv().await.map_err(Error::from),
77            None => Err(Error::UnsupportedProtocol(P::id())),
78        }
79    }
80
81    /// Registers a listener for the given Protocol `P`.
82    pub async fn register_protocol(&mut self, protocol_id: String) {
83        let listener = self.protocol_events.register(&protocol_id).await;
84        self.listeners.insert(protocol_id, listener);
85    }
86
87    pub async fn emit_msg(&self, id: &ProtocolID, event: &ProtocolEvent) -> Result<()> {
88        self.protocol_events.emit_by_topic(id, event).await?;
89        Ok(())
90    }
91
92    pub async fn recv_inner(&self) -> Result<NetMsg> {
93        self.conn.recv().await
94    }
95
96    pub async fn send_inner(&self, msg: NetMsg) -> Result<()> {
97        self.conn.send(msg).await
98    }
99
100    pub async fn disconnect(&self, res: Result<()>) -> Result<()> {
101        self.protocol_events.clear().await;
102        self.disconnect_signal.send(res).await?;
103
104        let m = NetMsg::new(NetMsgCmd::Shutdown, ShutdownMsg(0)).expect("Create shutdown message");
105        self.conn.send(m).await?;
106
107        Ok(())
108    }
109}