1pub mod builder;
2mod message_dispatcher;
3mod subscriptions;
4
5use std::{
6 sync::{
7 atomic::{AtomicBool, Ordering},
8 Arc,
9 },
10 time::Duration,
11};
12
13use async_channel::{Receiver, Sender};
14use log::{debug, error, info};
15use serde::{de::DeserializeOwned, Deserialize, Serialize};
16use serde_json::json;
17
18#[cfg(feature = "ws")]
19use karyon_net::ws::ClientWsConfig;
20#[cfg(all(feature = "ws", feature = "tls"))]
21use karyon_net::ws::ClientWssConfig;
22#[cfg(feature = "tls")]
23use karyon_net::{async_rustls::rustls, tls::ClientTlsConfig};
24
25#[cfg(feature = "tcp")]
26use crate::net::TcpConfig;
27
28use karyon_net::Conn;
29
30use karyon_core::{
31 async_util::{select, timeout, Either, TaskGroup, TaskResult},
32 util::random_32,
33};
34
35use crate::codec::ClonableJsonCodec;
36
37use crate::{
38 error::{Error, Result},
39 message::{self, SubscriptionID},
40 net::Endpoint,
41};
42
43pub use builder::ClientBuilder;
44pub use subscriptions::Subscription;
45
46use message_dispatcher::MessageDispatcher;
47use subscriptions::Subscriptions;
48
49type RequestID = u32;
50
51struct ClientConfig {
52 endpoint: Endpoint,
53 #[cfg(feature = "tcp")]
54 tcp_config: TcpConfig,
55 #[cfg(feature = "tls")]
56 tls_config: Option<(rustls::ClientConfig, String)>,
57 timeout: Option<u64>,
58 subscription_buffer_size: usize,
59}
60
61pub struct Client<C> {
63 disconnect: AtomicBool,
64 message_dispatcher: MessageDispatcher,
65 subscriptions: Arc<Subscriptions>,
66 send_chan: (Sender<serde_json::Value>, Receiver<serde_json::Value>),
67 task_group: TaskGroup,
68 config: ClientConfig,
69 codec: C,
70}
71
72#[derive(Serialize, Deserialize)]
73#[serde(untagged)]
74enum NewMsg {
75 Notification(message::Notification),
76 Response(message::Response),
77}
78
79impl<C> Client<C>
80where
81 C: ClonableJsonCodec + 'static,
82{
83 pub async fn call<T: Serialize + DeserializeOwned, V: DeserializeOwned>(
85 &self,
86 method: &str,
87 params: T,
88 ) -> Result<V> {
89 let response = self.send_request(method, params).await?;
90
91 match response.result {
92 Some(result) => Ok(serde_json::from_value::<V>(result)?),
93 None => Err(Error::InvalidMsg("Invalid response result".to_string())),
94 }
95 }
96
97 pub async fn subscribe<T: Serialize + DeserializeOwned>(
103 &self,
104 method: &str,
105 params: T,
106 ) -> Result<Arc<Subscription>> {
107 let response = self.send_request(method, params).await?;
108
109 let sub_id = match response.result {
110 Some(result) => serde_json::from_value::<SubscriptionID>(result)?,
111 None => return Err(Error::InvalidMsg("Invalid subscription id".to_string())),
112 };
113
114 let sub = self.subscriptions.subscribe(sub_id).await;
115
116 Ok(sub)
117 }
118
119 pub async fn unsubscribe(&self, method: &str, sub_id: SubscriptionID) -> Result<()> {
124 let _ = self.send_request(method, sub_id).await?;
125 self.subscriptions.unsubscribe(&sub_id).await;
126 Ok(())
127 }
128
129 pub async fn stop(&self) {
131 self.task_group.cancel().await;
132 }
133
134 async fn send_request<T: Serialize + DeserializeOwned>(
135 &self,
136 method: &str,
137 params: T,
138 ) -> Result<message::Response> {
139 let id: RequestID = random_32();
140 let request = message::Request {
141 jsonrpc: message::JSONRPC_VERSION.to_string(),
142 id: json!(id),
143 method: method.to_string(),
144 params: Some(json!(params)),
145 };
146
147 self.send(request).await?;
149
150 let rx = self.message_dispatcher.register(id).await;
152
153 let result = match self.config.timeout {
155 Some(t) => timeout(Duration::from_millis(t), rx.recv()).await?,
156 None => rx.recv().await,
157 };
158
159 let response = match result {
160 Ok(r) => r,
161 Err(err) => {
162 self.message_dispatcher.unregister(&id).await;
164 return Err(err.into());
165 }
166 };
167
168 if let Some(error) = response.error {
169 return Err(Error::SubscribeError(error.code, error.message));
170 }
171
172 if *response.id.as_ref().expect("Get response id") != id {
175 return Err(Error::InvalidMsg("Invalid response id".to_string()));
176 }
177
178 Ok(response)
179 }
180
181 async fn send(&self, req: message::Request) -> Result<()> {
182 if self.disconnect.load(Ordering::Relaxed) {
183 return Err(Error::ClientDisconnected);
184 }
185 let req = serde_json::to_value(req)?;
186 self.send_chan.0.send(req).await?;
187 Ok(())
188 }
189
190 async fn init(config: ClientConfig, codec: C) -> Result<Arc<Self>> {
192 let client = Arc::new(Client {
193 disconnect: AtomicBool::new(false),
194 subscriptions: Subscriptions::new(config.subscription_buffer_size),
195 send_chan: async_channel::bounded(10),
196 message_dispatcher: MessageDispatcher::new(),
197 task_group: TaskGroup::new(),
198 config,
199 codec,
200 });
201
202 let conn = client.connect().await?;
203 info!(
204 "Successfully connected to the RPC server: {}",
205 conn.peer_endpoint()?
206 );
207 client.start_background_loop(conn);
208 Ok(client)
209 }
210
211 async fn connect(self: &Arc<Self>) -> Result<Conn<serde_json::Value, Error>> {
212 let endpoint = self.config.endpoint.clone();
213 let codec = self.codec.clone();
214 let conn: Conn<serde_json::Value, Error> = match endpoint {
215 #[cfg(feature = "tcp")]
216 Endpoint::Tcp(..) => Box::new(
217 karyon_net::tcp::dial(&endpoint, self.config.tcp_config.clone(), codec).await?,
218 ),
219 #[cfg(feature = "tls")]
220 Endpoint::Tls(..) => match &self.config.tls_config {
221 Some((conf, dns_name)) => Box::new(
222 karyon_net::tls::dial(
223 &self.config.endpoint,
224 ClientTlsConfig {
225 dns_name: dns_name.to_string(),
226 client_config: conf.clone(),
227 tcp_config: self.config.tcp_config.clone(),
228 },
229 codec,
230 )
231 .await?,
232 ),
233 None => return Err(Error::TLSConfigRequired),
234 },
235 #[cfg(feature = "ws")]
236 Endpoint::Ws(..) => {
237 let config = ClientWsConfig {
238 tcp_config: self.config.tcp_config.clone(),
239 wss_config: None,
240 };
241 Box::new(karyon_net::ws::dial(&endpoint, config, codec).await?)
242 }
243 #[cfg(all(feature = "ws", feature = "tls"))]
244 Endpoint::Wss(..) => match &self.config.tls_config {
245 Some((conf, dns_name)) => Box::new(
246 karyon_net::ws::dial(
247 &endpoint,
248 ClientWsConfig {
249 tcp_config: self.config.tcp_config.clone(),
250 wss_config: Some(ClientWssConfig {
251 dns_name: dns_name.clone(),
252 client_config: conf.clone(),
253 }),
254 },
255 codec,
256 )
257 .await?,
258 ),
259 None => return Err(Error::TLSConfigRequired),
260 },
261 #[cfg(all(feature = "unix", target_family = "unix"))]
262 Endpoint::Unix(..) => {
263 Box::new(karyon_net::unix::dial(&endpoint, Default::default(), codec).await?)
264 }
265 _ => return Err(Error::UnsupportedProtocol(endpoint.to_string())),
266 };
267
268 Ok(conn)
269 }
270
271 fn start_background_loop(self: &Arc<Self>, conn: Conn<serde_json::Value, Error>) {
272 let on_complete = {
273 let this = self.clone();
274 |result: TaskResult<Result<()>>| async move {
275 if let TaskResult::Completed(Err(err)) = result {
276 error!("Client stopped: {err}");
277 }
278 this.disconnect.store(true, Ordering::Relaxed);
279 this.subscriptions.clear().await;
280 this.message_dispatcher.clear().await;
281 }
282 };
283
284 self.task_group.spawn(
286 {
287 let this = self.clone();
288 async move { this.background_loop(conn).await }
289 },
290 on_complete,
291 );
292 }
293
294 async fn background_loop(self: Arc<Self>, conn: Conn<serde_json::Value, Error>) -> Result<()> {
295 loop {
296 match select(self.send_chan.1.recv(), conn.recv()).await {
297 Either::Left(req) => {
298 conn.send(req?).await?;
299 }
300 Either::Right(msg) => match self.handle_msg(msg?).await {
301 Err(Error::SubscriptionBufferFull) => {
302 return Err(Error::SubscriptionBufferFull);
303 }
304 Err(err) => {
305 let endpoint = conn.peer_endpoint()?;
306 error!("Handle a new msg from the endpoint {endpoint} : {err}",);
307 }
308 Ok(_) => {}
309 },
310 }
311 }
312 }
313
314 async fn handle_msg(&self, msg: serde_json::Value) -> Result<()> {
315 match serde_json::from_value::<NewMsg>(msg.clone()) {
316 Ok(msg) => match msg {
317 NewMsg::Response(res) => {
318 debug!("<-- {res}");
319 self.message_dispatcher.dispatch(res).await
320 }
321 NewMsg::Notification(nt) => {
322 debug!("<-- {nt}");
323 self.subscriptions.notify(nt).await
324 }
325 },
326 Err(err) => {
327 error!("Receive unexpected msg {msg}: {err}");
328 Err(Error::InvalidMsg("Unexpected msg".to_string()))
329 }
330 }
331 }
332}