1use std::{marker::PhantomData, sync::Arc};
2
3use log::{error, trace, warn};
4
5use karyon_core::{
6 async_runtime::Executor,
7 async_util::{Backoff, TaskGroup, TaskResult},
8 crypto::KeyPair,
9};
10use karyon_net::{codec::Codec, framed, tcp, ByteBuffer, ClientLayer, Endpoint, FramedConn};
11
12use crate::{
13 codec::PeerNetMsgCodec,
14 conn_queue::ConnQueue,
15 monitor::{ConnectionKind, Monitor},
16 peer::ConnDirection,
17 slots::ConnectionSlots,
18 tls_config::{peer_id_from_certs, tls_client_config},
19 Error, PeerID, Result,
20};
21
22#[cfg(feature = "quic")]
23use karyon_net::{quic, StreamMux};
24
/// Fixed DNS name handed to the TLS client config (`ClientTlsConfig::dns_name`)
/// and to the QUIC client config on every outbound dial, independent of the
/// endpoint actually being dialed.
// NOTE(review): presumably the remote is authenticated through its certificate /
// PeerID (see `tls_client_config`, `peer_id_from_certs`) rather than this name —
// confirm in `tls_config`.
static DNS_NAME: &str = "karyontech.net";
26
27enum DialResult<C: Codec<ByteBuffer> + Default + Clone> {
32 Channel(FramedConn<C>, Option<PeerID>),
34 #[cfg(feature = "quic")]
36 Quic(FramedConn<C>, quic::QuicConn, Option<PeerID>),
37}
38
39pub struct Connector<C: Codec<ByteBuffer> + Default + Clone> {
45 key_pair: KeyPair,
46 task_group: TaskGroup,
47 connection_slots: Arc<ConnectionSlots>,
48 max_retries: usize,
49 conn_queue: Option<Arc<ConnQueue>>,
50 monitor: Arc<Monitor>,
51 _codec: PhantomData<C>,
52}
53
54impl<C> Connector<C>
55where
56 C: Codec<ByteBuffer, Error = karyon_net::Error> + Default + Clone + Send + Sync + 'static,
57{
58 pub fn new(
61 key_pair: &KeyPair,
62 max_retries: usize,
63 connection_slots: Arc<ConnectionSlots>,
64 monitor: Arc<Monitor>,
65 ex: Executor,
66 ) -> Arc<Self> {
67 Arc::new(Self {
68 key_pair: key_pair.clone(),
69 max_retries,
70 task_group: TaskGroup::with_executor(ex),
71 monitor,
72 connection_slots,
73 conn_queue: None,
74 _codec: PhantomData,
75 })
76 }
77
78 pub async fn shutdown(&self) {
79 self.task_group.cancel().await;
80 }
81
82 pub async fn connect(
84 &self,
85 endpoint: &Endpoint,
86 peer_id: &Option<PeerID>,
87 ) -> Result<FramedConn<C>> {
88 let result = self.connect_internal(endpoint, peer_id).await?;
89 match result {
90 DialResult::Channel(conn, _) => Ok(conn),
91 #[cfg(feature = "quic")]
92 DialResult::Quic(conn, _, _) => Ok(conn),
93 }
94 }
95
96 async fn connect_internal(
98 &self,
99 endpoint: &Endpoint,
100 peer_id: &Option<PeerID>,
101 ) -> Result<DialResult<C>> {
102 self.connection_slots.wait_for_slot().await;
103 self.connection_slots.add();
104
105 let mut retry = 0;
106 let backoff = Backoff::new(500, 2000);
107 while retry < self.max_retries {
108 match self.dial(endpoint, peer_id).await {
109 Ok(result) => {
110 self.monitor
111 .notify(ConnectionKind::Connected(endpoint.clone()))
112 .await;
113 return Ok(result);
114 }
115 Err(err) => {
116 error!("Failed to connect to {endpoint}: {err}");
117 }
118 }
119
120 self.monitor
121 .notify(ConnectionKind::ConnectRetried(endpoint.clone()))
122 .await;
123
124 backoff.sleep().await;
125 warn!("try to reconnect {endpoint}");
126 retry += 1;
127 }
128
129 self.monitor
130 .notify(ConnectionKind::ConnectFailed(endpoint.clone()))
131 .await;
132
133 self.connection_slots.remove().await;
134 Err(Error::Timeout)
135 }
136
137 async fn dial(&self, endpoint: &Endpoint, peer_id: &Option<PeerID>) -> Result<DialResult<C>> {
139 match endpoint {
140 Endpoint::Tcp(..) => {
141 let stream = tcp::connect(endpoint, Default::default()).await?;
142 let conn = framed(stream, C::default());
143 Ok(DialResult::Channel(conn, None))
144 }
145 Endpoint::Tls(..) => {
146 let tls_config = karyon_net::tls::ClientTlsConfig {
147 client_config: tls_client_config(&self.key_pair, peer_id.clone())?,
148 dns_name: DNS_NAME.to_string(),
149 };
150 let stream = tcp::connect(endpoint, Default::default()).await?;
151 let tls_layer = karyon_net::tls::TlsLayer::client(tls_config);
152 let tls_stream = ClientLayer::handshake(&tls_layer, stream).await?;
153 let vpid = tls_stream
155 .peer_certificates()
156 .as_deref()
157 .and_then(peer_id_from_certs);
158 let conn = framed(tls_stream, C::default());
159 Ok(DialResult::Channel(conn, vpid))
160 }
161 #[cfg(feature = "quic")]
162 Endpoint::Quic(..) => {
163 let rustls_config = tls_client_config(&self.key_pair, peer_id.clone())?;
164 let client_config = quic::ClientQuicConfig::from_rustls(rustls_config, DNS_NAME);
165 let quic_conn = quic::QuicEndpoint::dial(endpoint, client_config).await?;
166
167 let vpid = quic_conn
168 .peer_certificates()
169 .as_deref()
170 .and_then(peer_id_from_certs);
171
172 let stream = quic_conn.open_stream().await?;
174 let conn = framed(stream, C::default());
175
176 Ok(DialResult::Quic(conn, quic_conn, vpid))
177 }
178 _ => Err(Error::UnsupportedEndpoint(endpoint.to_string())),
179 }
180 }
181}
182
183impl Connector<PeerNetMsgCodec> {
186 pub fn new_with_queue(
188 key_pair: &KeyPair,
189 max_retries: usize,
190 connection_slots: Arc<ConnectionSlots>,
191 conn_queue: Arc<ConnQueue>,
192 monitor: Arc<Monitor>,
193 ex: Executor,
194 ) -> Arc<Self> {
195 Arc::new(Self {
196 key_pair: key_pair.clone(),
197 max_retries,
198 task_group: TaskGroup::with_executor(ex),
199 monitor,
200 connection_slots,
201 conn_queue: Some(conn_queue),
202 _codec: PhantomData,
203 })
204 }
205
206 pub async fn connect_and_queue(
210 self: &Arc<Self>,
211 endpoint: &Endpoint,
212 peer_id: &Option<PeerID>,
213 ) -> Result<()> {
214 let dial_result = self.connect_internal(endpoint, peer_id).await?;
215
216 let endpoint = endpoint.clone();
217 let on_disconnect = {
218 let this = self.clone();
219 |res: TaskResult<Result<()>>| async move {
220 if let TaskResult::Completed(Err(err)) = res {
221 trace!("Outbound connection dropped: {err}");
222 }
223 this.monitor
224 .notify(ConnectionKind::Disconnected(endpoint.clone()))
225 .await;
226 this.connection_slots.remove().await;
227 }
228 };
229
230 let conn_queue = self
231 .conn_queue
232 .as_ref()
233 .ok_or_else(|| {
234 Error::Config(
235 "connect_and_queue called on Connector built without ConnQueue".into(),
236 )
237 })?
238 .clone();
239 self.task_group.spawn(
240 async move {
241 match dial_result {
242 DialResult::Channel(conn, vpid) => {
243 conn_queue.handle(conn, ConnDirection::Outbound, vpid).await
244 }
245 #[cfg(feature = "quic")]
246 DialResult::Quic(conn, quic_conn, vpid) => {
247 conn_queue
248 .handle_quic(conn, quic_conn, ConnDirection::Outbound, vpid)
249 .await
250 }
251 }
252 },
253 on_disconnect,
254 );
255
256 Ok(())
257 }
258}