1use std::{future::Future, sync::Arc};
2
3use log::{error, trace, warn};
4
5use karyon_core::{
6 async_runtime::Executor,
7 async_util::{Backoff, TaskGroup, TaskResult},
8 crypto::KeyPair,
9};
10use karyon_net::{tcp, tls, Endpoint};
11
12use crate::{
13 codec::NetMsgCodec,
14 monitor::{ConnEvent, Monitor},
15 slots::ConnectionSlots,
16 tls_config::tls_client_config,
17 ConnRef, Error, PeerID, Result,
18};
19
/// Server name passed as `dns_name` in the TLS client configuration when
/// dialing peers over TLS.
const DNS_NAME: &str = "karyontech.net";
21
22pub struct Connector {
24 key_pair: KeyPair,
26
27 task_group: TaskGroup,
29
30 connection_slots: Arc<ConnectionSlots>,
32
33 max_retries: usize,
36
37 enable_tls: bool,
39
40 monitor: Arc<Monitor>,
42}
43
44impl Connector {
45 pub fn new(
47 key_pair: &KeyPair,
48 max_retries: usize,
49 connection_slots: Arc<ConnectionSlots>,
50 enable_tls: bool,
51 monitor: Arc<Monitor>,
52 ex: Executor,
53 ) -> Arc<Self> {
54 Arc::new(Self {
55 key_pair: key_pair.clone(),
56 max_retries,
57 task_group: TaskGroup::with_executor(ex),
58 monitor,
59 connection_slots,
60 enable_tls,
61 })
62 }
63
64 pub async fn shutdown(&self) {
66 self.task_group.cancel().await;
67 }
68
69 pub async fn connect(&self, endpoint: &Endpoint, peer_id: &Option<PeerID>) -> Result<ConnRef> {
76 self.connection_slots.wait_for_slot().await;
77 self.connection_slots.add();
78
79 let mut retry = 0;
80 let backoff = Backoff::new(500, 2000);
81 while retry < self.max_retries {
82 match self.dial(endpoint, peer_id).await {
83 Ok(conn) => {
84 self.monitor
85 .notify(ConnEvent::Connected(endpoint.clone()))
86 .await;
87 return Ok(conn);
88 }
89 Err(err) => {
90 error!("Failed to establish a connection to {endpoint}: {err}");
91 }
92 }
93
94 self.monitor
95 .notify(ConnEvent::ConnectRetried(endpoint.clone()))
96 .await;
97
98 backoff.sleep().await;
99
100 warn!("try to reconnect {endpoint}");
101 retry += 1;
102 }
103
104 self.monitor
105 .notify(ConnEvent::ConnectFailed(endpoint.clone()))
106 .await;
107
108 self.connection_slots.remove().await;
109 Err(Error::Timeout)
110 }
111
112 pub async fn connect_with_cback<Fut>(
115 self: &Arc<Self>,
116 endpoint: &Endpoint,
117 peer_id: &Option<PeerID>,
118 callback: impl FnOnce(ConnRef) -> Fut + Send + 'static,
119 ) -> Result<()>
120 where
121 Fut: Future<Output = Result<()>> + Send + 'static,
122 {
123 let conn = self.connect(endpoint, peer_id).await?;
124
125 let endpoint = endpoint.clone();
126 let on_disconnect = {
127 let this = self.clone();
128 |res| async move {
129 if let TaskResult::Completed(Err(err)) = res {
130 trace!("Outbound connection dropped: {err}");
131 }
132 this.monitor
133 .notify(ConnEvent::Disconnected(endpoint.clone()))
134 .await;
135 this.connection_slots.remove().await;
136 }
137 };
138
139 self.task_group.spawn(callback(conn), on_disconnect);
140
141 Ok(())
142 }
143
144 async fn dial(&self, endpoint: &Endpoint, peer_id: &Option<PeerID>) -> Result<ConnRef> {
145 if self.enable_tls {
146 if !endpoint.is_tcp() && !endpoint.is_tls() {
147 return Err(Error::UnsupportedEndpoint(endpoint.to_string()));
148 }
149
150 let tls_config = tls::ClientTlsConfig {
151 tcp_config: Default::default(),
152 client_config: tls_client_config(&self.key_pair, peer_id.clone())?,
153 dns_name: DNS_NAME.to_string(),
154 };
155 let c = tls::dial(endpoint, tls_config, NetMsgCodec::new()).await?;
156 Ok(Box::new(c))
157 } else {
158 if !endpoint.is_tcp() {
159 return Err(Error::UnsupportedEndpoint(endpoint.to_string()));
160 }
161
162 let c = tcp::dial(endpoint, tcp::TcpConfig::default(), NetMsgCodec::new()).await?;
163 Ok(Box::new(c))
164 }
165 }
166}