1use std::{collections::HashMap, fmt, sync::Arc};
2
3use async_channel::Sender;
4use bincode::Encode;
5
6use karyon_core::util::encode;
7use karyon_eventemitter::{EventEmitter, EventListener};
8
9use karyon_net::Endpoint;
10
11use crate::{
12 message::{NetMsg, NetMsgCmd, ProtocolMsg, ShutdownMsg},
13 protocol::{Protocol, ProtocolEvent, ProtocolID},
14 ConnRef, Error, Result,
15};
16
/// Direction of a network connection: inbound or outbound.
///
/// The type is a fieldless enum, so it is `Copy` and can be compared with `==`
/// or used as a `HashMap`/`HashSet` key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ConnDirection {
    /// Inbound connection.
    Inbound,
    /// Outbound connection.
    Outbound,
}
23
24impl fmt::Display for ConnDirection {
25 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
26 match self {
27 ConnDirection::Inbound => write!(f, "Inbound"),
28 ConnDirection::Outbound => write!(f, "Outbound"),
29 }
30 }
31}
32
/// A connection to a remote peer, with per-protocol event listeners layered on top.
pub struct Connection {
    /// Whether this connection is inbound or outbound.
    pub(crate) direction: ConnDirection,
    /// Underlying connection used to send and receive `NetMsg`s.
    conn: ConnRef,
    /// Channel on which `disconnect` forwards the disconnect result.
    disconnect_signal: Sender<Result<()>>,
    /// Emitter that dispatches `ProtocolEvent`s by `ProtocolID` topic.
    protocol_events: Arc<EventEmitter<ProtocolID>>,
    /// Endpoint of the remote peer.
    pub(crate) remote_endpoint: Endpoint,
    /// One listener per registered protocol, keyed by protocol ID.
    listeners: HashMap<ProtocolID, EventListener<ProtocolID, ProtocolEvent>>,
}
42
43impl Connection {
44 pub fn new(
45 conn: ConnRef,
46 signal: Sender<Result<()>>,
47 direction: ConnDirection,
48 remote_endpoint: Endpoint,
49 ) -> Self {
50 Self {
51 conn,
52 direction,
53 protocol_events: EventEmitter::new(),
54 disconnect_signal: signal,
55 remote_endpoint,
56 listeners: HashMap::new(),
57 }
58 }
59
60 pub async fn send<T: Encode>(&self, protocol_id: ProtocolID, msg: T) -> Result<()> {
61 let payload = encode(&msg)?;
62
63 let proto_msg = ProtocolMsg {
64 protocol_id,
65 payload: payload.to_vec(),
66 };
67
68 let msg = NetMsg::new(NetMsgCmd::Protocol, &proto_msg)?;
69 self.conn.send(msg).await
70 }
71
72 pub async fn recv<P: Protocol>(&self) -> Result<ProtocolEvent> {
73 match self.listeners.get(&P::id()) {
74 Some(l) => l.recv().await.map_err(Error::from),
75 None => Err(Error::UnsupportedProtocol(P::id())),
76 }
77 }
78
79 pub async fn register_protocol(&mut self, protocol_id: String) {
81 let listener = self.protocol_events.register(&protocol_id);
82 self.listeners.insert(protocol_id, listener);
83 }
84
85 pub async fn emit_msg(&self, id: &ProtocolID, event: &ProtocolEvent) -> Result<()> {
86 self.protocol_events.emit_by_topic(id, event).await?;
87 Ok(())
88 }
89
90 pub async fn recv_inner(&self) -> Result<NetMsg> {
91 self.conn.recv().await
92 }
93
94 pub async fn send_inner(&self, msg: NetMsg) -> Result<()> {
95 self.conn.send(msg).await
96 }
97
98 pub async fn disconnect(&self, res: Result<()>) -> Result<()> {
99 self.protocol_events.clear();
100 self.disconnect_signal.send(res).await?;
101
102 let m = NetMsg::new(NetMsgCmd::Shutdown, ShutdownMsg(0)).expect("Create shutdown message");
103 self.conn.send(m).await?;
104
105 Ok(())
106 }
107}