Skip to main content

karyon_p2p/
listener.rs

1use std::{future::Future, marker::PhantomData, sync::Arc};
2
3use log::{debug, error, info};
4
5use karyon_core::{
6    async_runtime::Executor,
7    async_util::{TaskGroup, TaskResult},
8    crypto::KeyPair,
9};
10use karyon_net::{
11    codec::Codec,
12    framed,
13    tcp::TcpListener,
14    tls::{ServerTlsConfig, TlsListener},
15    ByteBuffer, ByteStream, Endpoint, FramedConn,
16};
17
18use crate::{
19    codec::PeerNetMsgCodec,
20    conn_queue::ConnQueue,
21    monitor::{ConnectionKind, Monitor},
22    peer::ConnDirection,
23    slots::ConnectionSlots,
24    tls_config::{peer_id_from_certs, tls_server_config},
25    Error, Result,
26};
27
28/// Listener for byte-stream transports (TCP, TLS). Each `accept` yields
29/// a single `Box<dyn ByteStream>`. QUIC uses a separate `StreamMux` path.
30enum StreamListener {
31    Tcp(TcpListener),
32    Tls(Box<TlsListener>),
33}
34
35impl StreamListener {
36    async fn accept(&self) -> Result<Box<dyn ByteStream>> {
37        match self {
38            Self::Tcp(l) => Ok(l.accept().await?),
39            Self::Tls(l) => Ok(l.accept().await?),
40        }
41    }
42
43    fn local_endpoint(&self) -> Result<Endpoint> {
44        match self {
45            Self::Tcp(l) => Ok(l.local_endpoint()?),
46            Self::Tls(l) => Ok(l.local_endpoint()?),
47        }
48    }
49}
50
51#[cfg(feature = "quic")]
52use karyon_net::{quic, StreamMux};
53
54/// Creates inbound connections with other peers. Generic over the
55/// codec applied to the framed accepted streams so the same accept
56/// machinery serves the peer data-plane (`PeerNetMsgCodec`) and the
57/// kademlia lookup-plane (`KadNetMsgCodec`).
58pub struct Listener<C: Codec<ByteBuffer> + Default + Clone> {
59    key_pair: KeyPair,
60    task_group: TaskGroup,
61    connection_slots: Arc<ConnectionSlots>,
62    conn_queue: Option<Arc<ConnQueue>>,
63    monitor: Arc<Monitor>,
64    _codec: PhantomData<C>,
65}
66
67impl<C> Listener<C>
68where
69    C: Codec<ByteBuffer, Error = karyon_net::Error> + Default + Clone + Send + Sync + 'static,
70{
71    /// Create a new Listener (no auto-queue; use `start_with_callback`).
72    pub fn new(
73        key_pair: &KeyPair,
74        connection_slots: Arc<ConnectionSlots>,
75        monitor: Arc<Monitor>,
76        ex: Executor,
77    ) -> Arc<Self> {
78        Arc::new(Self {
79            key_pair: key_pair.clone(),
80            connection_slots,
81            conn_queue: None,
82            task_group: TaskGroup::with_executor(ex),
83            monitor,
84            _codec: PhantomData,
85        })
86    }
87
88    /// Start with a user-provided callback for each connection.
89    pub async fn start_with_callback<Fut>(
90        self: &Arc<Self>,
91        endpoint: Endpoint,
92        callback: impl FnOnce(FramedConn<C>) -> Fut + Clone + Send + 'static,
93    ) -> Result<Endpoint>
94    where
95        Fut: Future<Output = Result<()>> + Send + 'static,
96    {
97        let listener = match self.listen(&endpoint).await {
98            Ok(l) => {
99                self.monitor
100                    .notify(ConnectionKind::Listening(endpoint.clone()))
101                    .await;
102                l
103            }
104            Err(err) => {
105                error!("Failed to listen on {endpoint}: {err}");
106                self.monitor
107                    .notify(ConnectionKind::ListenFailed(endpoint))
108                    .await;
109                return Err(err);
110            }
111        };
112
113        let resolved = listener.local_endpoint()?;
114        info!("Start listening on {resolved}");
115
116        self.task_group.spawn(
117            {
118                let this = self.clone();
119                async move { this.listen_loop_callback(listener, callback).await }
120            },
121            |res: TaskResult<()>| async move {
122                debug!("Listener callback loop ended: {res}");
123            },
124        );
125        Ok(resolved)
126    }
127
128    pub async fn shutdown(&self) {
129        self.task_group.cancel().await;
130    }
131
132    /// Accept loop (callback mode).
133    async fn listen_loop_callback<Fut>(
134        self: Arc<Self>,
135        listener: StreamListener,
136        callback: impl FnOnce(FramedConn<C>) -> Fut + Clone + Send + 'static,
137    ) where
138        Fut: Future<Output = Result<()>> + Send + 'static,
139    {
140        loop {
141            self.connection_slots.wait_for_slot().await;
142            let result = listener.accept().await;
143
144            let conn: FramedConn<C> = match result {
145                Ok(stream) => framed(stream, C::default()),
146                Err(err) => {
147                    error!("Failed to accept connection: {err}");
148                    self.monitor.notify(ConnectionKind::AcceptFailed).await;
149                    continue;
150                }
151            };
152
153            let endpoint = match conn.peer_endpoint() {
154                Some(ep) => ep,
155                None => {
156                    self.monitor.notify(ConnectionKind::AcceptFailed).await;
157                    error!("Failed to get peer endpoint");
158                    continue;
159                }
160            };
161
162            self.monitor
163                .notify(ConnectionKind::Accepted(endpoint.clone()))
164                .await;
165            self.connection_slots.add();
166
167            let on_disconnect = {
168                let this = self.clone();
169                |res| async move {
170                    if let TaskResult::Completed(Err(err)) = res {
171                        debug!("Inbound connection dropped: {err}");
172                    }
173                    this.monitor
174                        .notify(ConnectionKind::Disconnected(endpoint))
175                        .await;
176                    this.connection_slots.remove().await;
177                }
178            };
179
180            let callback = callback.clone();
181            self.task_group.spawn(callback(conn), on_disconnect);
182        }
183    }
184
185    /// Create a listener for TCP/TLS.
186    async fn listen(&self, endpoint: &Endpoint) -> Result<StreamListener> {
187        match endpoint {
188            Endpoint::Tcp(..) => {
189                let listener = TcpListener::bind(endpoint, Default::default()).await?;
190                Ok(StreamListener::Tcp(listener))
191            }
192            Endpoint::Tls(..) => {
193                let tls_config = ServerTlsConfig {
194                    server_config: tls_server_config(&self.key_pair)?,
195                };
196                let tcp_listener = TcpListener::bind(endpoint, Default::default()).await?;
197                let listener = TlsListener::new(tcp_listener, tls_config);
198                Ok(StreamListener::Tls(Box::new(listener)))
199            }
200            _ => Err(Error::UnsupportedEndpoint(endpoint.to_string())),
201        }
202    }
203}
204
205// Auto-queue paths only live on the peer-plane Listener. The kademlia
206// lookup plane uses `start_with_callback` and handles each connection
207// inline (no ConnQueue / handshake pipeline).
208impl Listener<PeerNetMsgCodec> {
209    /// Create a new Listener with a ConnQueue (auto-queue mode).
210    pub fn new_with_queue(
211        key_pair: &KeyPair,
212        connection_slots: Arc<ConnectionSlots>,
213        conn_queue: Arc<ConnQueue>,
214        monitor: Arc<Monitor>,
215        ex: Executor,
216    ) -> Arc<Self> {
217        Arc::new(Self {
218            key_pair: key_pair.clone(),
219            connection_slots,
220            conn_queue: Some(conn_queue),
221            task_group: TaskGroup::with_executor(ex),
222            monitor,
223            _codec: PhantomData,
224        })
225    }
226
227    /// Start listening (auto-queue mode). Returns the resolved endpoint.
228    pub async fn start(self: &Arc<Self>, endpoint: Endpoint) -> Result<Endpoint> {
229        #[cfg(feature = "quic")]
230        if endpoint.is_quic() {
231            return self.start_quic(endpoint).await;
232        }
233
234        let listener = match self.listen(&endpoint).await {
235            Ok(l) => {
236                self.monitor
237                    .notify(ConnectionKind::Listening(endpoint.clone()))
238                    .await;
239                l
240            }
241            Err(err) => {
242                error!("Failed to listen on {endpoint}: {err}");
243                self.monitor
244                    .notify(ConnectionKind::ListenFailed(endpoint))
245                    .await;
246                return Err(err);
247            }
248        };
249
250        let resolved = listener.local_endpoint()?;
251        info!("Start listening on {resolved}");
252
253        self.task_group.spawn(
254            {
255                let this = self.clone();
256                async move { this.listen_loop(listener).await }
257            },
258            |_| async {},
259        );
260        Ok(resolved)
261    }
262
263    /// Accept loop (auto-queue mode).
264    async fn listen_loop(self: Arc<Self>, listener: StreamListener) {
265        let conn_queue = self
266            .conn_queue
267            .as_ref()
268            .expect("listen_loop requires ConnQueue")
269            .clone();
270
271        loop {
272            self.connection_slots.wait_for_slot().await;
273            let result = listener.accept().await;
274
275            let (conn, vpid) = match result {
276                Ok(stream) => {
277                    // Extract peer cert (TLS) before framing consumes the stream.
278                    let vpid = stream
279                        .peer_certificates()
280                        .as_deref()
281                        .and_then(peer_id_from_certs);
282                    let conn: FramedConn<PeerNetMsgCodec> = framed(stream, PeerNetMsgCodec::new());
283                    (conn, vpid)
284                }
285                Err(err) => {
286                    error!("Failed to accept connection: {err}");
287                    self.monitor.notify(ConnectionKind::AcceptFailed).await;
288                    continue;
289                }
290            };
291
292            let endpoint = match conn.peer_endpoint() {
293                Some(ep) => ep,
294                None => {
295                    self.monitor.notify(ConnectionKind::AcceptFailed).await;
296                    error!("Failed to get peer endpoint");
297                    continue;
298                }
299            };
300
301            self.monitor
302                .notify(ConnectionKind::Accepted(endpoint.clone()))
303                .await;
304            self.connection_slots.add();
305
306            let on_disconnect = {
307                let this = self.clone();
308                |res: TaskResult<Result<()>>| async move {
309                    if let TaskResult::Completed(Err(err)) = res {
310                        debug!("Inbound connection dropped: {err}");
311                    }
312                    this.monitor
313                        .notify(ConnectionKind::Disconnected(endpoint))
314                        .await;
315                    this.connection_slots.remove().await;
316                }
317            };
318
319            let cq = conn_queue.clone();
320            self.task_group.spawn(
321                async move {
322                    cq.handle(conn, ConnDirection::Inbound, vpid).await?;
323                    Ok(())
324                },
325                on_disconnect,
326            );
327        }
328    }
329
330    /// QUIC listener.
331    #[cfg(feature = "quic")]
332    async fn start_quic(self: &Arc<Self>, endpoint: Endpoint) -> Result<Endpoint> {
333        let rustls_config = tls_server_config(&self.key_pair)?;
334        let server_config = quic::ServerQuicConfig::from_rustls(rustls_config);
335
336        let quic_endpoint = match quic::QuicEndpoint::listen(&endpoint, server_config).await {
337            Ok(ep) => {
338                self.monitor
339                    .notify(ConnectionKind::Listening(endpoint.clone()))
340                    .await;
341                ep
342            }
343            Err(err) => {
344                error!("Failed to listen on {endpoint}: {err}");
345                self.monitor
346                    .notify(ConnectionKind::ListenFailed(endpoint))
347                    .await;
348                return Err(err.into());
349            }
350        };
351
352        let resolved: Endpoint = quic_endpoint.local_endpoint().map_err(Error::from)?;
353        info!("Start listening on {resolved}");
354
355        self.task_group.spawn(
356            {
357                let this = self.clone();
358                async move { this.listen_loop_quic(quic_endpoint).await }
359            },
360            |res: TaskResult<()>| async move {
361                debug!("QUIC listen loop ended: {res}");
362            },
363        );
364
365        Ok(resolved)
366    }
367
368    /// QUIC accept loop.
369    #[cfg(feature = "quic")]
370    async fn listen_loop_quic(self: Arc<Self>, quic_endpoint: quic::QuicEndpoint) {
371        loop {
372            self.connection_slots.wait_for_slot().await;
373
374            let quic_conn = match quic_endpoint.accept().await {
375                Ok(c) => c,
376                Err(err) => {
377                    error!("Failed to accept QUIC conn: {err}");
378                    self.monitor.notify(ConnectionKind::AcceptFailed).await;
379                    continue;
380                }
381            };
382
383            let peer_ep = match quic_conn.peer_endpoint() {
384                Ok(ep) => ep,
385                Err(err) => {
386                    error!("Failed to get peer endpoint: {err}");
387                    self.monitor.notify(ConnectionKind::AcceptFailed).await;
388                    continue;
389                }
390            };
391
392            self.monitor
393                .notify(ConnectionKind::Accepted(peer_ep.clone()))
394                .await;
395
396            let vpid = quic_conn
397                .peer_certificates()
398                .as_deref()
399                .and_then(peer_id_from_certs);
400
401            let stream = match quic_conn.accept_stream().await {
402                Ok(s) => s,
403                Err(err) => {
404                    error!("Failed to accept handshake stream: {err}");
405                    self.monitor.notify(ConnectionKind::AcceptFailed).await;
406                    continue;
407                }
408            };
409
410            let conn: FramedConn<PeerNetMsgCodec> = framed(stream, PeerNetMsgCodec::new());
411
412            self.connection_slots.add();
413
414            let on_disconnect = {
415                let this = self.clone();
416                |res: TaskResult<Result<()>>| async move {
417                    if let TaskResult::Completed(Err(err)) = res {
418                        debug!("Inbound QUIC conn dropped: {err}");
419                    }
420                    this.monitor
421                        .notify(ConnectionKind::Disconnected(peer_ep))
422                        .await;
423                    this.connection_slots.remove().await;
424                }
425            };
426
427            let conn_queue = self
428                .conn_queue
429                .as_ref()
430                .expect("QUIC listener requires ConnQueue")
431                .clone();
432            self.task_group.spawn(
433                async move {
434                    conn_queue
435                        .handle_quic(conn, quic_conn, ConnDirection::Inbound, vpid)
436                        .await?;
437                    Ok(())
438                },
439                on_disconnect,
440            );
441        }
442    }
443}