Skip to main content

karyon_p2p/
connector.rs

1use std::{marker::PhantomData, sync::Arc};
2
3use log::{error, trace, warn};
4
5use karyon_core::{
6    async_runtime::Executor,
7    async_util::{Backoff, TaskGroup, TaskResult},
8    crypto::KeyPair,
9};
10use karyon_net::{codec::Codec, framed, tcp, ByteBuffer, ClientLayer, Endpoint, FramedConn};
11
12use crate::{
13    codec::PeerNetMsgCodec,
14    conn_queue::ConnQueue,
15    monitor::{ConnectionKind, Monitor},
16    peer::ConnDirection,
17    slots::ConnectionSlots,
18    tls_config::{peer_id_from_certs, tls_client_config},
19    Error, PeerID, Result,
20};
21
22#[cfg(feature = "quic")]
23use karyon_net::{quic, StreamMux};
24
/// Server name handed to the TLS client configuration (`dns_name`) and to
/// the QUIC client configuration for every outbound dial in this file.
25static DNS_NAME: &str = "karyontech.net";
26
27/// Internal dial result. Generic over the codec used to frame the
28/// resulting connection. Carries the optional cert-derived PeerID so
29/// the caller can stamp it on the QueuedConn for the application
30/// handshake to enforce.
31enum DialResult<C: Codec<ByteBuffer> + Default + Clone> {
32    /// TCP or TLS: a framed connection.
33    Channel(FramedConn<C>, Option<PeerID>),
34    /// QUIC: handshake stream + full QUIC connection.
35    #[cfg(feature = "quic")]
36    Quic(FramedConn<C>, quic::QuicConn, Option<PeerID>),
37}
38
39/// Creates outbound connections with other peers. Generic over the
40/// codec applied to the resulting framed stream so the same machinery
41/// (TLS, retries, slots, monitor events) serves both the peer
42/// data-plane (`PeerNetMsgCodec`) and the kademlia lookup-plane
43/// (`KadNetMsgCodec`).
44pub struct Connector<C: Codec<ByteBuffer> + Default + Clone> {
45    key_pair: KeyPair,
46    task_group: TaskGroup,
47    connection_slots: Arc<ConnectionSlots>,
48    max_retries: usize,
49    conn_queue: Option<Arc<ConnQueue>>,
50    monitor: Arc<Monitor>,
51    _codec: PhantomData<C>,
52}
53
54impl<C> Connector<C>
55where
56    C: Codec<ByteBuffer, Error = karyon_net::Error> + Default + Clone + Send + Sync + 'static,
57{
58    /// Create a new Connector without a ConnQueue (used for plain
59    /// `connect` paths like Kademlia lookup).
60    pub fn new(
61        key_pair: &KeyPair,
62        max_retries: usize,
63        connection_slots: Arc<ConnectionSlots>,
64        monitor: Arc<Monitor>,
65        ex: Executor,
66    ) -> Arc<Self> {
67        Arc::new(Self {
68            key_pair: key_pair.clone(),
69            max_retries,
70            task_group: TaskGroup::with_executor(ex),
71            monitor,
72            connection_slots,
73            conn_queue: None,
74            _codec: PhantomData,
75        })
76    }
77
78    pub async fn shutdown(&self) {
79        self.task_group.cancel().await;
80    }
81
82    /// Connect and return the framed connection.
83    pub async fn connect(
84        &self,
85        endpoint: &Endpoint,
86        peer_id: &Option<PeerID>,
87    ) -> Result<FramedConn<C>> {
88        let result = self.connect_internal(endpoint, peer_id).await?;
89        match result {
90            DialResult::Channel(conn, _) => Ok(conn),
91            #[cfg(feature = "quic")]
92            DialResult::Quic(conn, _, _) => Ok(conn),
93        }
94    }
95
96    /// Dial with retries.
97    async fn connect_internal(
98        &self,
99        endpoint: &Endpoint,
100        peer_id: &Option<PeerID>,
101    ) -> Result<DialResult<C>> {
102        self.connection_slots.wait_for_slot().await;
103        self.connection_slots.add();
104
105        let mut retry = 0;
106        let backoff = Backoff::new(500, 2000);
107        while retry < self.max_retries {
108            match self.dial(endpoint, peer_id).await {
109                Ok(result) => {
110                    self.monitor
111                        .notify(ConnectionKind::Connected(endpoint.clone()))
112                        .await;
113                    return Ok(result);
114                }
115                Err(err) => {
116                    error!("Failed to connect to {endpoint}: {err}");
117                }
118            }
119
120            self.monitor
121                .notify(ConnectionKind::ConnectRetried(endpoint.clone()))
122                .await;
123
124            backoff.sleep().await;
125            warn!("try to reconnect {endpoint}");
126            retry += 1;
127        }
128
129        self.monitor
130            .notify(ConnectionKind::ConnectFailed(endpoint.clone()))
131            .await;
132
133        self.connection_slots.remove().await;
134        Err(Error::Timeout)
135    }
136
137    /// Dial, selecting transport based on endpoint type.
138    async fn dial(&self, endpoint: &Endpoint, peer_id: &Option<PeerID>) -> Result<DialResult<C>> {
139        match endpoint {
140            Endpoint::Tcp(..) => {
141                let stream = tcp::connect(endpoint, Default::default()).await?;
142                let conn = framed(stream, C::default());
143                Ok(DialResult::Channel(conn, None))
144            }
145            Endpoint::Tls(..) => {
146                let tls_config = karyon_net::tls::ClientTlsConfig {
147                    client_config: tls_client_config(&self.key_pair, peer_id.clone())?,
148                    dns_name: DNS_NAME.to_string(),
149                };
150                let stream = tcp::connect(endpoint, Default::default()).await?;
151                let tls_layer = karyon_net::tls::TlsLayer::client(tls_config);
152                let tls_stream = ClientLayer::handshake(&tls_layer, stream).await?;
153                // Extract the peer cert before framing consumes the stream.
154                let vpid = tls_stream
155                    .peer_certificates()
156                    .as_deref()
157                    .and_then(peer_id_from_certs);
158                let conn = framed(tls_stream, C::default());
159                Ok(DialResult::Channel(conn, vpid))
160            }
161            #[cfg(feature = "quic")]
162            Endpoint::Quic(..) => {
163                let rustls_config = tls_client_config(&self.key_pair, peer_id.clone())?;
164                let client_config = quic::ClientQuicConfig::from_rustls(rustls_config, DNS_NAME);
165                let quic_conn = quic::QuicEndpoint::dial(endpoint, client_config).await?;
166
167                let vpid = quic_conn
168                    .peer_certificates()
169                    .as_deref()
170                    .and_then(peer_id_from_certs);
171
172                // First stream for handshake via StreamMux.
173                let stream = quic_conn.open_stream().await?;
174                let conn = framed(stream, C::default());
175
176                Ok(DialResult::Quic(conn, quic_conn, vpid))
177            }
178            _ => Err(Error::UnsupportedEndpoint(endpoint.to_string())),
179        }
180    }
181}
182
183// `connect_and_queue` only lives on the peer-plane Connector since the
184// ConnQueue is part of the data-plane handshake pipeline.
185impl Connector<PeerNetMsgCodec> {
186    /// Create a new Connector with a ConnQueue.
187    pub fn new_with_queue(
188        key_pair: &KeyPair,
189        max_retries: usize,
190        connection_slots: Arc<ConnectionSlots>,
191        conn_queue: Arc<ConnQueue>,
192        monitor: Arc<Monitor>,
193        ex: Executor,
194    ) -> Arc<Self> {
195        Arc::new(Self {
196            key_pair: key_pair.clone(),
197            max_retries,
198            task_group: TaskGroup::with_executor(ex),
199            monitor,
200            connection_slots,
201            conn_queue: Some(conn_queue),
202            _codec: PhantomData,
203        })
204    }
205
206    /// Connect, queue for handshake, and run until the peer
207    /// disconnects. Disconnect/handshake-failure is observable via
208    /// `PeerPool::register_peer_events`, not via this call.
209    pub async fn connect_and_queue(
210        self: &Arc<Self>,
211        endpoint: &Endpoint,
212        peer_id: &Option<PeerID>,
213    ) -> Result<()> {
214        let dial_result = self.connect_internal(endpoint, peer_id).await?;
215
216        let endpoint = endpoint.clone();
217        let on_disconnect = {
218            let this = self.clone();
219            |res: TaskResult<Result<()>>| async move {
220                if let TaskResult::Completed(Err(err)) = res {
221                    trace!("Outbound connection dropped: {err}");
222                }
223                this.monitor
224                    .notify(ConnectionKind::Disconnected(endpoint.clone()))
225                    .await;
226                this.connection_slots.remove().await;
227            }
228        };
229
230        let conn_queue = self
231            .conn_queue
232            .as_ref()
233            .ok_or_else(|| {
234                Error::Config(
235                    "connect_and_queue called on Connector built without ConnQueue".into(),
236                )
237            })?
238            .clone();
239        self.task_group.spawn(
240            async move {
241                match dial_result {
242                    DialResult::Channel(conn, vpid) => {
243                        conn_queue.handle(conn, ConnDirection::Outbound, vpid).await
244                    }
245                    #[cfg(feature = "quic")]
246                    DialResult::Quic(conn, quic_conn, vpid) => {
247                        conn_queue
248                            .handle_quic(conn, quic_conn, ConnDirection::Outbound, vpid)
249                            .await
250                    }
251                }
252            },
253            on_disconnect,
254        );
255
256        Ok(())
257    }
258}