1use std::{future::Future, marker::PhantomData, sync::Arc};
2
3use log::{debug, error, info};
4
5use karyon_core::{
6 async_runtime::Executor,
7 async_util::{TaskGroup, TaskResult},
8 crypto::KeyPair,
9};
10use karyon_net::{
11 codec::Codec,
12 framed,
13 tcp::TcpListener,
14 tls::{ServerTlsConfig, TlsListener},
15 ByteBuffer, ByteStream, Endpoint, FramedConn,
16};
17
18use crate::{
19 codec::PeerNetMsgCodec,
20 conn_queue::ConnQueue,
21 monitor::{ConnectionKind, Monitor},
22 peer::ConnDirection,
23 slots::ConnectionSlots,
24 tls_config::{peer_id_from_certs, tls_server_config},
25 Error, Result,
26};
27
28enum StreamListener {
31 Tcp(TcpListener),
32 Tls(Box<TlsListener>),
33}
34
35impl StreamListener {
36 async fn accept(&self) -> Result<Box<dyn ByteStream>> {
37 match self {
38 Self::Tcp(l) => Ok(l.accept().await?),
39 Self::Tls(l) => Ok(l.accept().await?),
40 }
41 }
42
43 fn local_endpoint(&self) -> Result<Endpoint> {
44 match self {
45 Self::Tcp(l) => Ok(l.local_endpoint()?),
46 Self::Tls(l) => Ok(l.local_endpoint()?),
47 }
48 }
49}
50
51#[cfg(feature = "quic")]
52use karyon_net::{quic, StreamMux};
53
/// Listens on endpoints and accepts inbound peer connections.
///
/// `C` is the codec used to frame each accepted byte stream into a
/// `FramedConn<C>`.
pub struct Listener<C: Codec<ByteBuffer> + Default + Clone> {
    /// Local identity; used to build the TLS (and QUIC) server configuration.
    key_pair: KeyPair,
    /// Runs the accept loops and the per-connection tasks they spawn;
    /// cancelled by `shutdown`.
    task_group: TaskGroup,
    /// Gates accepting: each loop waits for a free slot before accepting,
    /// adds one per accepted connection and removes it when that connection's
    /// task ends.
    connection_slots: Arc<ConnectionSlots>,
    /// Queue that accepted connections are handed to. Only `new_with_queue`
    /// sets this to `Some`; starting a `Listener<PeerNetMsgCodec>` via `start`
    /// while it is `None` results in a panic.
    conn_queue: Option<Arc<ConnQueue>>,
    /// Receives listening / accept / disconnect events.
    monitor: Arc<Monitor>,
    _codec: PhantomData<C>,
}
66
67impl<C> Listener<C>
68where
69 C: Codec<ByteBuffer, Error = karyon_net::Error> + Default + Clone + Send + Sync + 'static,
70{
71 pub fn new(
73 key_pair: &KeyPair,
74 connection_slots: Arc<ConnectionSlots>,
75 monitor: Arc<Monitor>,
76 ex: Executor,
77 ) -> Arc<Self> {
78 Arc::new(Self {
79 key_pair: key_pair.clone(),
80 connection_slots,
81 conn_queue: None,
82 task_group: TaskGroup::with_executor(ex),
83 monitor,
84 _codec: PhantomData,
85 })
86 }
87
88 pub async fn start_with_callback<Fut>(
90 self: &Arc<Self>,
91 endpoint: Endpoint,
92 callback: impl FnOnce(FramedConn<C>) -> Fut + Clone + Send + 'static,
93 ) -> Result<Endpoint>
94 where
95 Fut: Future<Output = Result<()>> + Send + 'static,
96 {
97 let listener = match self.listen(&endpoint).await {
98 Ok(l) => {
99 self.monitor
100 .notify(ConnectionKind::Listening(endpoint.clone()))
101 .await;
102 l
103 }
104 Err(err) => {
105 error!("Failed to listen on {endpoint}: {err}");
106 self.monitor
107 .notify(ConnectionKind::ListenFailed(endpoint))
108 .await;
109 return Err(err);
110 }
111 };
112
113 let resolved = listener.local_endpoint()?;
114 info!("Start listening on {resolved}");
115
116 self.task_group.spawn(
117 {
118 let this = self.clone();
119 async move { this.listen_loop_callback(listener, callback).await }
120 },
121 |res: TaskResult<()>| async move {
122 debug!("Listener callback loop ended: {res}");
123 },
124 );
125 Ok(resolved)
126 }
127
128 pub async fn shutdown(&self) {
129 self.task_group.cancel().await;
130 }
131
132 async fn listen_loop_callback<Fut>(
134 self: Arc<Self>,
135 listener: StreamListener,
136 callback: impl FnOnce(FramedConn<C>) -> Fut + Clone + Send + 'static,
137 ) where
138 Fut: Future<Output = Result<()>> + Send + 'static,
139 {
140 loop {
141 self.connection_slots.wait_for_slot().await;
142 let result = listener.accept().await;
143
144 let conn: FramedConn<C> = match result {
145 Ok(stream) => framed(stream, C::default()),
146 Err(err) => {
147 error!("Failed to accept connection: {err}");
148 self.monitor.notify(ConnectionKind::AcceptFailed).await;
149 continue;
150 }
151 };
152
153 let endpoint = match conn.peer_endpoint() {
154 Some(ep) => ep,
155 None => {
156 self.monitor.notify(ConnectionKind::AcceptFailed).await;
157 error!("Failed to get peer endpoint");
158 continue;
159 }
160 };
161
162 self.monitor
163 .notify(ConnectionKind::Accepted(endpoint.clone()))
164 .await;
165 self.connection_slots.add();
166
167 let on_disconnect = {
168 let this = self.clone();
169 |res| async move {
170 if let TaskResult::Completed(Err(err)) = res {
171 debug!("Inbound connection dropped: {err}");
172 }
173 this.monitor
174 .notify(ConnectionKind::Disconnected(endpoint))
175 .await;
176 this.connection_slots.remove().await;
177 }
178 };
179
180 let callback = callback.clone();
181 self.task_group.spawn(callback(conn), on_disconnect);
182 }
183 }
184
185 async fn listen(&self, endpoint: &Endpoint) -> Result<StreamListener> {
187 match endpoint {
188 Endpoint::Tcp(..) => {
189 let listener = TcpListener::bind(endpoint, Default::default()).await?;
190 Ok(StreamListener::Tcp(listener))
191 }
192 Endpoint::Tls(..) => {
193 let tls_config = ServerTlsConfig {
194 server_config: tls_server_config(&self.key_pair)?,
195 };
196 let tcp_listener = TcpListener::bind(endpoint, Default::default()).await?;
197 let listener = TlsListener::new(tcp_listener, tls_config);
198 Ok(StreamListener::Tls(Box::new(listener)))
199 }
200 _ => Err(Error::UnsupportedEndpoint(endpoint.to_string())),
201 }
202 }
203}
204
205impl Listener<PeerNetMsgCodec> {
209 pub fn new_with_queue(
211 key_pair: &KeyPair,
212 connection_slots: Arc<ConnectionSlots>,
213 conn_queue: Arc<ConnQueue>,
214 monitor: Arc<Monitor>,
215 ex: Executor,
216 ) -> Arc<Self> {
217 Arc::new(Self {
218 key_pair: key_pair.clone(),
219 connection_slots,
220 conn_queue: Some(conn_queue),
221 task_group: TaskGroup::with_executor(ex),
222 monitor,
223 _codec: PhantomData,
224 })
225 }
226
227 pub async fn start(self: &Arc<Self>, endpoint: Endpoint) -> Result<Endpoint> {
229 #[cfg(feature = "quic")]
230 if endpoint.is_quic() {
231 return self.start_quic(endpoint).await;
232 }
233
234 let listener = match self.listen(&endpoint).await {
235 Ok(l) => {
236 self.monitor
237 .notify(ConnectionKind::Listening(endpoint.clone()))
238 .await;
239 l
240 }
241 Err(err) => {
242 error!("Failed to listen on {endpoint}: {err}");
243 self.monitor
244 .notify(ConnectionKind::ListenFailed(endpoint))
245 .await;
246 return Err(err);
247 }
248 };
249
250 let resolved = listener.local_endpoint()?;
251 info!("Start listening on {resolved}");
252
253 self.task_group.spawn(
254 {
255 let this = self.clone();
256 async move { this.listen_loop(listener).await }
257 },
258 |_| async {},
259 );
260 Ok(resolved)
261 }
262
263 async fn listen_loop(self: Arc<Self>, listener: StreamListener) {
265 let conn_queue = self
266 .conn_queue
267 .as_ref()
268 .expect("listen_loop requires ConnQueue")
269 .clone();
270
271 loop {
272 self.connection_slots.wait_for_slot().await;
273 let result = listener.accept().await;
274
275 let (conn, vpid) = match result {
276 Ok(stream) => {
277 let vpid = stream
279 .peer_certificates()
280 .as_deref()
281 .and_then(peer_id_from_certs);
282 let conn: FramedConn<PeerNetMsgCodec> = framed(stream, PeerNetMsgCodec::new());
283 (conn, vpid)
284 }
285 Err(err) => {
286 error!("Failed to accept connection: {err}");
287 self.monitor.notify(ConnectionKind::AcceptFailed).await;
288 continue;
289 }
290 };
291
292 let endpoint = match conn.peer_endpoint() {
293 Some(ep) => ep,
294 None => {
295 self.monitor.notify(ConnectionKind::AcceptFailed).await;
296 error!("Failed to get peer endpoint");
297 continue;
298 }
299 };
300
301 self.monitor
302 .notify(ConnectionKind::Accepted(endpoint.clone()))
303 .await;
304 self.connection_slots.add();
305
306 let on_disconnect = {
307 let this = self.clone();
308 |res: TaskResult<Result<()>>| async move {
309 if let TaskResult::Completed(Err(err)) = res {
310 debug!("Inbound connection dropped: {err}");
311 }
312 this.monitor
313 .notify(ConnectionKind::Disconnected(endpoint))
314 .await;
315 this.connection_slots.remove().await;
316 }
317 };
318
319 let cq = conn_queue.clone();
320 self.task_group.spawn(
321 async move {
322 cq.handle(conn, ConnDirection::Inbound, vpid).await?;
323 Ok(())
324 },
325 on_disconnect,
326 );
327 }
328 }
329
330 #[cfg(feature = "quic")]
332 async fn start_quic(self: &Arc<Self>, endpoint: Endpoint) -> Result<Endpoint> {
333 let rustls_config = tls_server_config(&self.key_pair)?;
334 let server_config = quic::ServerQuicConfig::from_rustls(rustls_config);
335
336 let quic_endpoint = match quic::QuicEndpoint::listen(&endpoint, server_config).await {
337 Ok(ep) => {
338 self.monitor
339 .notify(ConnectionKind::Listening(endpoint.clone()))
340 .await;
341 ep
342 }
343 Err(err) => {
344 error!("Failed to listen on {endpoint}: {err}");
345 self.monitor
346 .notify(ConnectionKind::ListenFailed(endpoint))
347 .await;
348 return Err(err.into());
349 }
350 };
351
352 let resolved: Endpoint = quic_endpoint.local_endpoint().map_err(Error::from)?;
353 info!("Start listening on {resolved}");
354
355 self.task_group.spawn(
356 {
357 let this = self.clone();
358 async move { this.listen_loop_quic(quic_endpoint).await }
359 },
360 |res: TaskResult<()>| async move {
361 debug!("QUIC listen loop ended: {res}");
362 },
363 );
364
365 Ok(resolved)
366 }
367
368 #[cfg(feature = "quic")]
370 async fn listen_loop_quic(self: Arc<Self>, quic_endpoint: quic::QuicEndpoint) {
371 loop {
372 self.connection_slots.wait_for_slot().await;
373
374 let quic_conn = match quic_endpoint.accept().await {
375 Ok(c) => c,
376 Err(err) => {
377 error!("Failed to accept QUIC conn: {err}");
378 self.monitor.notify(ConnectionKind::AcceptFailed).await;
379 continue;
380 }
381 };
382
383 let peer_ep = match quic_conn.peer_endpoint() {
384 Ok(ep) => ep,
385 Err(err) => {
386 error!("Failed to get peer endpoint: {err}");
387 self.monitor.notify(ConnectionKind::AcceptFailed).await;
388 continue;
389 }
390 };
391
392 self.monitor
393 .notify(ConnectionKind::Accepted(peer_ep.clone()))
394 .await;
395
396 let vpid = quic_conn
397 .peer_certificates()
398 .as_deref()
399 .and_then(peer_id_from_certs);
400
401 let stream = match quic_conn.accept_stream().await {
402 Ok(s) => s,
403 Err(err) => {
404 error!("Failed to accept handshake stream: {err}");
405 self.monitor.notify(ConnectionKind::AcceptFailed).await;
406 continue;
407 }
408 };
409
410 let conn: FramedConn<PeerNetMsgCodec> = framed(stream, PeerNetMsgCodec::new());
411
412 self.connection_slots.add();
413
414 let on_disconnect = {
415 let this = self.clone();
416 |res: TaskResult<Result<()>>| async move {
417 if let TaskResult::Completed(Err(err)) = res {
418 debug!("Inbound QUIC conn dropped: {err}");
419 }
420 this.monitor
421 .notify(ConnectionKind::Disconnected(peer_ep))
422 .await;
423 this.connection_slots.remove().await;
424 }
425 };
426
427 let conn_queue = self
428 .conn_queue
429 .as_ref()
430 .expect("QUIC listener requires ConnQueue")
431 .clone();
432 self.task_group.spawn(
433 async move {
434 conn_queue
435 .handle_quic(conn, quic_conn, ConnDirection::Inbound, vpid)
436 .await?;
437 Ok(())
438 },
439 on_disconnect,
440 );
441 }
442 }
443}