1use std::{future::Future, sync::Arc};
2
3use log::{debug, error, info};
4
5use karyon_core::{
6 async_runtime::Executor,
7 async_util::{TaskGroup, TaskResult},
8 crypto::KeyPair,
9};
10
11use karyon_net::{tcp, tls, Endpoint};
12
13use crate::{
14 codec::NetMsgCodec,
15 message::NetMsg,
16 monitor::{ConnEvent, Monitor},
17 slots::ConnectionSlots,
18 tls_config::tls_server_config,
19 ConnRef, Error, ListenerRef, Result,
20};
21
22pub struct Listener {
24 key_pair: KeyPair,
26
27 task_group: TaskGroup,
29
30 connection_slots: Arc<ConnectionSlots>,
32
33 enable_tls: bool,
35
36 monitor: Arc<Monitor>,
38}
39
40impl Listener {
41 pub fn new(
43 key_pair: &KeyPair,
44 connection_slots: Arc<ConnectionSlots>,
45 enable_tls: bool,
46 monitor: Arc<Monitor>,
47 ex: Executor,
48 ) -> Arc<Self> {
49 Arc::new(Self {
50 key_pair: key_pair.clone(),
51 connection_slots,
52 task_group: TaskGroup::with_executor(ex),
53 enable_tls,
54 monitor,
55 })
56 }
57
58 pub async fn start<Fut>(
64 self: &Arc<Self>,
65 endpoint: Endpoint,
66 callback: impl FnOnce(ConnRef) -> Fut + Clone + Send + 'static,
68 ) -> Result<Endpoint>
69 where
70 Fut: Future<Output = Result<()>> + Send + 'static,
71 {
72 let listener = match self.listen(&endpoint).await {
73 Ok(listener) => {
74 self.monitor
75 .notify(ConnEvent::Listening(endpoint.clone()))
76 .await;
77 listener
78 }
79 Err(err) => {
80 error!("Failed to listen on {endpoint}: {err}");
81 self.monitor.notify(ConnEvent::ListenFailed(endpoint)).await;
82 return Err(err);
83 }
84 };
85
86 let resolved_endpoint = listener.local_endpoint()?;
87
88 info!("Start listening on {resolved_endpoint}");
89
90 self.task_group.spawn(
91 {
92 let this = self.clone();
93 async move { this.listen_loop(listener, callback).await }
94 },
95 |_| async {},
96 );
97 Ok(resolved_endpoint)
98 }
99
100 pub async fn shutdown(&self) {
102 self.task_group.cancel().await;
103 }
104
105 async fn listen_loop<Fut>(
106 self: Arc<Self>,
107 listener: karyon_net::Listener<NetMsg, Error>,
108 callback: impl FnOnce(ConnRef) -> Fut + Clone + Send + 'static,
109 ) where
110 Fut: Future<Output = Result<()>> + Send + 'static,
111 {
112 loop {
113 self.connection_slots.wait_for_slot().await;
115 let result = listener.accept().await;
116
117 let (conn, endpoint) = match result {
118 Ok(c) => {
119 let endpoint = match c.peer_endpoint() {
120 Ok(ep) => ep,
121 Err(err) => {
122 self.monitor.notify(ConnEvent::AcceptFailed).await;
123 error!("Failed to accept a new connection: {err}");
124 continue;
125 }
126 };
127
128 self.monitor
129 .notify(ConnEvent::Accepted(endpoint.clone()))
130 .await;
131 (c, endpoint)
132 }
133 Err(err) => {
134 error!("Failed to accept a new connection: {err}");
135 self.monitor.notify(ConnEvent::AcceptFailed).await;
136 continue;
137 }
138 };
139
140 self.connection_slots.add();
141
142 let on_disconnect = {
143 let this = self.clone();
144 |res| async move {
145 if let TaskResult::Completed(Err(err)) = res {
146 debug!("Inbound connection dropped: {err}");
147 }
148 this.monitor.notify(ConnEvent::Disconnected(endpoint)).await;
149 this.connection_slots.remove().await;
150 }
151 };
152
153 let callback = callback.clone();
154 self.task_group.spawn(callback(conn), on_disconnect);
155 }
156 }
157
158 async fn listen(&self, endpoint: &Endpoint) -> Result<ListenerRef> {
159 if self.enable_tls {
160 if !endpoint.is_tcp() && !endpoint.is_tls() {
161 return Err(Error::UnsupportedEndpoint(endpoint.to_string()));
162 }
163
164 let tls_config = tls::ServerTlsConfig {
165 tcp_config: Default::default(),
166 server_config: tls_server_config(&self.key_pair)?,
167 };
168 let l = tls::listen(endpoint, tls_config, NetMsgCodec::new()).await?;
169 Ok(Box::new(l))
170 } else {
171 if !endpoint.is_tcp() {
172 return Err(Error::UnsupportedEndpoint(endpoint.to_string()));
173 }
174
175 let l = tcp::listen(endpoint, tcp::TcpConfig::default(), NetMsgCodec::new()).await?;
176 Ok(Box::new(l))
177 }
178 }
179}