karyon_p2p/
connector.rs

1use std::{future::Future, sync::Arc};
2
3use log::{error, trace, warn};
4
5use karyon_core::{
6    async_runtime::Executor,
7    async_util::{Backoff, TaskGroup, TaskResult},
8    crypto::KeyPair,
9};
10use karyon_net::{tcp, tls, Endpoint};
11
12use crate::{
13    codec::NetMsgCodec,
14    monitor::{ConnEvent, Monitor},
15    slots::ConnectionSlots,
16    tls_config::tls_client_config,
17    ConnRef, Error, PeerID, Result,
18};
19
/// Server name handed to the TLS client configuration (`dns_name`) whenever
/// `Connector::dial` opens a TLS connection.
const DNS_NAME: &str = "karyontech.net";
21
22/// Responsible for creating outbound connections with other peers.
23pub struct Connector {
24    /// Identity Key pair
25    key_pair: KeyPair,
26
27    /// Managing spawned tasks.
28    task_group: TaskGroup,
29
30    /// Manages available outbound slots.
31    connection_slots: Arc<ConnectionSlots>,
32
33    /// The maximum number of retries allowed before successfully
34    /// establishing a connection.
35    max_retries: usize,
36
37    /// Enables secure connection.
38    enable_tls: bool,
39
40    /// Responsible for network and system monitoring.
41    monitor: Arc<Monitor>,
42}
43
44impl Connector {
45    /// Creates a new Connector
46    pub fn new(
47        key_pair: &KeyPair,
48        max_retries: usize,
49        connection_slots: Arc<ConnectionSlots>,
50        enable_tls: bool,
51        monitor: Arc<Monitor>,
52        ex: Executor,
53    ) -> Arc<Self> {
54        Arc::new(Self {
55            key_pair: key_pair.clone(),
56            max_retries,
57            task_group: TaskGroup::with_executor(ex),
58            monitor,
59            connection_slots,
60            enable_tls,
61        })
62    }
63
64    /// Shuts down the connector
65    pub async fn shutdown(&self) {
66        self.task_group.cancel().await;
67    }
68
69    /// Establish a connection to the specified `endpoint`. If the connection
70    /// attempt fails, it performs a backoff and retries until the maximum allowed
71    /// number of retries is exceeded. On a successful connection, it returns a
72    /// `Conn` instance.
73    ///
74    /// This method will block until it finds an available slot.
75    pub async fn connect(&self, endpoint: &Endpoint, peer_id: &Option<PeerID>) -> Result<ConnRef> {
76        self.connection_slots.wait_for_slot().await;
77        self.connection_slots.add();
78
79        let mut retry = 0;
80        let backoff = Backoff::new(500, 2000);
81        while retry < self.max_retries {
82            match self.dial(endpoint, peer_id).await {
83                Ok(conn) => {
84                    self.monitor
85                        .notify(ConnEvent::Connected(endpoint.clone()))
86                        .await;
87                    return Ok(conn);
88                }
89                Err(err) => {
90                    error!("Failed to establish a connection to {endpoint}: {err}");
91                }
92            }
93
94            self.monitor
95                .notify(ConnEvent::ConnectRetried(endpoint.clone()))
96                .await;
97
98            backoff.sleep().await;
99
100            warn!("try to reconnect {endpoint}");
101            retry += 1;
102        }
103
104        self.monitor
105            .notify(ConnEvent::ConnectFailed(endpoint.clone()))
106            .await;
107
108        self.connection_slots.remove().await;
109        Err(Error::Timeout)
110    }
111
112    /// Establish a connection to the given `endpoint`. For each new connection,
113    /// it invokes the provided `callback`, and pass the connection to the callback.
114    pub async fn connect_with_cback<Fut>(
115        self: &Arc<Self>,
116        endpoint: &Endpoint,
117        peer_id: &Option<PeerID>,
118        callback: impl FnOnce(ConnRef) -> Fut + Send + 'static,
119    ) -> Result<()>
120    where
121        Fut: Future<Output = Result<()>> + Send + 'static,
122    {
123        let conn = self.connect(endpoint, peer_id).await?;
124
125        let endpoint = endpoint.clone();
126        let on_disconnect = {
127            let this = self.clone();
128            |res| async move {
129                if let TaskResult::Completed(Err(err)) = res {
130                    trace!("Outbound connection dropped: {err}");
131                }
132                this.monitor
133                    .notify(ConnEvent::Disconnected(endpoint.clone()))
134                    .await;
135                this.connection_slots.remove().await;
136            }
137        };
138
139        self.task_group.spawn(callback(conn), on_disconnect);
140
141        Ok(())
142    }
143
144    async fn dial(&self, endpoint: &Endpoint, peer_id: &Option<PeerID>) -> Result<ConnRef> {
145        if self.enable_tls {
146            if !endpoint.is_tcp() && !endpoint.is_tls() {
147                return Err(Error::UnsupportedEndpoint(endpoint.to_string()));
148            }
149
150            let tls_config = tls::ClientTlsConfig {
151                tcp_config: Default::default(),
152                client_config: tls_client_config(&self.key_pair, peer_id.clone())?,
153                dns_name: DNS_NAME.to_string(),
154            };
155            let c = tls::dial(endpoint, tls_config, NetMsgCodec::new()).await?;
156            Ok(Box::new(c))
157        } else {
158            if !endpoint.is_tcp() {
159                return Err(Error::UnsupportedEndpoint(endpoint.to_string()));
160            }
161
162            let c = tcp::dial(endpoint, tcp::TcpConfig::default(), NetMsgCodec::new()).await?;
163            Ok(Box::new(c))
164        }
165    }
166}