1use std::{collections::HashMap, fmt, sync::Arc};
2
3use async_channel::Sender;
4use bincode::Encode;
5
6use karyon_core::{
7 event::{EventEmitter, EventListener},
8 util::encode,
9};
10
11use karyon_net::Endpoint;
12
13use crate::{
14 message::{NetMsg, NetMsgCmd, ProtocolMsg, ShutdownMsg},
15 protocol::{Protocol, ProtocolEvent, ProtocolID},
16 ConnRef, Error, Result,
17};
18
/// Whether a connection is inbound or outbound.
///
/// Rendered as its variant name by the `Display` impl.
#[derive(Debug, Clone)]
pub enum ConnDirection {
    /// An inbound connection.
    Inbound,
    /// An outbound connection.
    Outbound,
}
25
26impl fmt::Display for ConnDirection {
27 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
28 match self {
29 ConnDirection::Inbound => write!(f, "Inbound"),
30 ConnDirection::Outbound => write!(f, "Outbound"),
31 }
32 }
33}
34
/// A connection to a remote endpoint, together with the per-protocol event
/// plumbing used to route `ProtocolEvent`s to the listeners registered on it.
pub struct Connection {
    /// Direction of this connection (inbound or outbound).
    pub(crate) direction: ConnDirection,
    /// Underlying connection used to send and receive `NetMsg`s.
    conn: ConnRef,
    /// Channel on which `disconnect` reports the disconnection result.
    disconnect_signal: Sender<Result<()>>,
    /// Emitter that dispatches `ProtocolEvent`s using the protocol ID as topic.
    protocol_events: Arc<EventEmitter<ProtocolID>>,
    /// Endpoint of the remote side of this connection.
    pub(crate) remote_endpoint: Endpoint,
    /// Listeners registered on `protocol_events`, keyed by protocol ID.
    listeners: HashMap<ProtocolID, EventListener<ProtocolID, ProtocolEvent>>,
}
44
45impl Connection {
46 pub fn new(
47 conn: ConnRef,
48 signal: Sender<Result<()>>,
49 direction: ConnDirection,
50 remote_endpoint: Endpoint,
51 ) -> Self {
52 Self {
53 conn,
54 direction,
55 protocol_events: EventEmitter::new(),
56 disconnect_signal: signal,
57 remote_endpoint,
58 listeners: HashMap::new(),
59 }
60 }
61
62 pub async fn send<T: Encode>(&self, protocol_id: ProtocolID, msg: T) -> Result<()> {
63 let payload = encode(&msg)?;
64
65 let proto_msg = ProtocolMsg {
66 protocol_id,
67 payload: payload.to_vec(),
68 };
69
70 let msg = NetMsg::new(NetMsgCmd::Protocol, &proto_msg)?;
71 self.conn.send(msg).await
72 }
73
74 pub async fn recv<P: Protocol>(&self) -> Result<ProtocolEvent> {
75 match self.listeners.get(&P::id()) {
76 Some(l) => l.recv().await.map_err(Error::from),
77 None => Err(Error::UnsupportedProtocol(P::id())),
78 }
79 }
80
81 pub async fn register_protocol(&mut self, protocol_id: String) {
83 let listener = self.protocol_events.register(&protocol_id).await;
84 self.listeners.insert(protocol_id, listener);
85 }
86
87 pub async fn emit_msg(&self, id: &ProtocolID, event: &ProtocolEvent) -> Result<()> {
88 self.protocol_events.emit_by_topic(id, event).await?;
89 Ok(())
90 }
91
92 pub async fn recv_inner(&self) -> Result<NetMsg> {
93 self.conn.recv().await
94 }
95
96 pub async fn send_inner(&self, msg: NetMsg) -> Result<()> {
97 self.conn.send(msg).await
98 }
99
100 pub async fn disconnect(&self, res: Result<()>) -> Result<()> {
101 self.protocol_events.clear().await;
102 self.disconnect_signal.send(res).await?;
103
104 let m = NetMsg::new(NetMsgCmd::Shutdown, ShutdownMsg(0)).expect("Create shutdown message");
105 self.conn.send(m).await?;
106
107 Ok(())
108 }
109}