karyon_jsonrpc/client/mod.rs

1pub mod builder;
2mod message_dispatcher;
3mod subscriptions;
4
5use std::{
6    sync::{
7        atomic::{AtomicBool, Ordering},
8        Arc,
9    },
10    time::Duration,
11};
12
13use async_channel::{Receiver, Sender};
14use log::{debug, error, info};
15use serde::{de::DeserializeOwned, Deserialize, Serialize};
16use serde_json::json;
17
18#[cfg(feature = "ws")]
19use karyon_net::ws::ClientWsConfig;
20#[cfg(all(feature = "ws", feature = "tls"))]
21use karyon_net::ws::ClientWssConfig;
22#[cfg(feature = "tls")]
23use karyon_net::{async_rustls::rustls, tls::ClientTlsConfig};
24
25#[cfg(feature = "tcp")]
26use crate::net::TcpConfig;
27
28use karyon_net::Conn;
29
30use karyon_core::{
31    async_util::{select, timeout, Either, TaskGroup, TaskResult},
32    util::random_32,
33};
34
35use crate::codec::ClonableJsonCodec;
36
37use crate::{
38    error::{Error, Result},
39    message::{self, SubscriptionID},
40    net::Endpoint,
41};
42
43pub use builder::ClientBuilder;
44pub use subscriptions::Subscription;
45
46use message_dispatcher::MessageDispatcher;
47use subscriptions::Subscriptions;
48
/// Identifier assigned to every outgoing request; the response that answers
/// it is expected to carry the same id.
type RequestID = u32;
50
51struct ClientConfig {
52    endpoint: Endpoint,
53    #[cfg(feature = "tcp")]
54    tcp_config: TcpConfig,
55    #[cfg(feature = "tls")]
56    tls_config: Option<(rustls::ClientConfig, String)>,
57    timeout: Option<u64>,
58    subscription_buffer_size: usize,
59}
60
61/// Represents an RPC client
62pub struct Client<C> {
63    disconnect: AtomicBool,
64    message_dispatcher: MessageDispatcher,
65    subscriptions: Arc<Subscriptions>,
66    send_chan: (Sender<serde_json::Value>, Receiver<serde_json::Value>),
67    task_group: TaskGroup,
68    config: ClientConfig,
69    codec: C,
70}
71
72#[derive(Serialize, Deserialize)]
73#[serde(untagged)]
74enum NewMsg {
75    Notification(message::Notification),
76    Response(message::Response),
77}
78
79impl<C> Client<C>
80where
81    C: ClonableJsonCodec + 'static,
82{
83    /// Calls the provided method, waits for the response, and returns the result.
84    pub async fn call<T: Serialize + DeserializeOwned, V: DeserializeOwned>(
85        &self,
86        method: &str,
87        params: T,
88    ) -> Result<V> {
89        let response = self.send_request(method, params).await?;
90
91        match response.result {
92            Some(result) => Ok(serde_json::from_value::<V>(result)?),
93            None => Err(Error::InvalidMsg("Invalid response result".to_string())),
94        }
95    }
96
97    /// Subscribes to the provided method, waits for the response, and returns the result.
98    ///
99    /// This function sends a subscription request to the specified method
100    /// with the given parameters. It waits for the response and returns a
101    /// `Subscription`.
102    pub async fn subscribe<T: Serialize + DeserializeOwned>(
103        &self,
104        method: &str,
105        params: T,
106    ) -> Result<Arc<Subscription>> {
107        let response = self.send_request(method, params).await?;
108
109        let sub_id = match response.result {
110            Some(result) => serde_json::from_value::<SubscriptionID>(result)?,
111            None => return Err(Error::InvalidMsg("Invalid subscription id".to_string())),
112        };
113
114        let sub = self.subscriptions.subscribe(sub_id).await;
115
116        Ok(sub)
117    }
118
119    /// Unsubscribes from the provided method, waits for the response, and returns the result.
120    ///
121    /// This function sends an unsubscription request for the specified method
122    /// and subscription ID. It waits for the response to confirm the unsubscription.
123    pub async fn unsubscribe(&self, method: &str, sub_id: SubscriptionID) -> Result<()> {
124        let _ = self.send_request(method, sub_id).await?;
125        self.subscriptions.unsubscribe(&sub_id).await;
126        Ok(())
127    }
128
129    /// Disconnect the client
130    pub async fn stop(&self) {
131        self.task_group.cancel().await;
132    }
133
134    async fn send_request<T: Serialize + DeserializeOwned>(
135        &self,
136        method: &str,
137        params: T,
138    ) -> Result<message::Response> {
139        let id: RequestID = random_32();
140        let request = message::Request {
141            jsonrpc: message::JSONRPC_VERSION.to_string(),
142            id: json!(id),
143            method: method.to_string(),
144            params: Some(json!(params)),
145        };
146
147        // Send the request
148        self.send(request).await?;
149
150        // Register a new request
151        let rx = self.message_dispatcher.register(id).await;
152
153        // Wait for the message dispatcher to send the response
154        let result = match self.config.timeout {
155            Some(t) => timeout(Duration::from_millis(t), rx.recv()).await?,
156            None => rx.recv().await,
157        };
158
159        let response = match result {
160            Ok(r) => r,
161            Err(err) => {
162                // Unregister the request if an error occurs
163                self.message_dispatcher.unregister(&id).await;
164                return Err(err.into());
165            }
166        };
167
168        if let Some(error) = response.error {
169            return Err(Error::SubscribeError(error.code, error.message));
170        }
171
172        // It should be OK to unwrap here, as the message dispatcher checks
173        // for the response id.
174        if *response.id.as_ref().expect("Get response id") != id {
175            return Err(Error::InvalidMsg("Invalid response id".to_string()));
176        }
177
178        Ok(response)
179    }
180
181    async fn send(&self, req: message::Request) -> Result<()> {
182        if self.disconnect.load(Ordering::Relaxed) {
183            return Err(Error::ClientDisconnected);
184        }
185        let req = serde_json::to_value(req)?;
186        self.send_chan.0.send(req).await?;
187        Ok(())
188    }
189
190    /// Initializes a new [`Client`] from the provided [`ClientConfig`].
191    async fn init(config: ClientConfig, codec: C) -> Result<Arc<Self>> {
192        let client = Arc::new(Client {
193            disconnect: AtomicBool::new(false),
194            subscriptions: Subscriptions::new(config.subscription_buffer_size),
195            send_chan: async_channel::bounded(10),
196            message_dispatcher: MessageDispatcher::new(),
197            task_group: TaskGroup::new(),
198            config,
199            codec,
200        });
201
202        let conn = client.connect().await?;
203        info!(
204            "Successfully connected to the RPC server: {}",
205            conn.peer_endpoint()?
206        );
207        client.start_background_loop(conn);
208        Ok(client)
209    }
210
211    async fn connect(self: &Arc<Self>) -> Result<Conn<serde_json::Value, Error>> {
212        let endpoint = self.config.endpoint.clone();
213        let codec = self.codec.clone();
214        let conn: Conn<serde_json::Value, Error> = match endpoint {
215            #[cfg(feature = "tcp")]
216            Endpoint::Tcp(..) => Box::new(
217                karyon_net::tcp::dial(&endpoint, self.config.tcp_config.clone(), codec).await?,
218            ),
219            #[cfg(feature = "tls")]
220            Endpoint::Tls(..) => match &self.config.tls_config {
221                Some((conf, dns_name)) => Box::new(
222                    karyon_net::tls::dial(
223                        &self.config.endpoint,
224                        ClientTlsConfig {
225                            dns_name: dns_name.to_string(),
226                            client_config: conf.clone(),
227                            tcp_config: self.config.tcp_config.clone(),
228                        },
229                        codec,
230                    )
231                    .await?,
232                ),
233                None => return Err(Error::TLSConfigRequired),
234            },
235            #[cfg(feature = "ws")]
236            Endpoint::Ws(..) => {
237                let config = ClientWsConfig {
238                    tcp_config: self.config.tcp_config.clone(),
239                    wss_config: None,
240                };
241                Box::new(karyon_net::ws::dial(&endpoint, config, codec).await?)
242            }
243            #[cfg(all(feature = "ws", feature = "tls"))]
244            Endpoint::Wss(..) => match &self.config.tls_config {
245                Some((conf, dns_name)) => Box::new(
246                    karyon_net::ws::dial(
247                        &endpoint,
248                        ClientWsConfig {
249                            tcp_config: self.config.tcp_config.clone(),
250                            wss_config: Some(ClientWssConfig {
251                                dns_name: dns_name.clone(),
252                                client_config: conf.clone(),
253                            }),
254                        },
255                        codec,
256                    )
257                    .await?,
258                ),
259                None => return Err(Error::TLSConfigRequired),
260            },
261            #[cfg(all(feature = "unix", target_family = "unix"))]
262            Endpoint::Unix(..) => {
263                Box::new(karyon_net::unix::dial(&endpoint, Default::default(), codec).await?)
264            }
265            _ => return Err(Error::UnsupportedProtocol(endpoint.to_string())),
266        };
267
268        Ok(conn)
269    }
270
271    fn start_background_loop(self: &Arc<Self>, conn: Conn<serde_json::Value, Error>) {
272        let on_complete = {
273            let this = self.clone();
274            |result: TaskResult<Result<()>>| async move {
275                if let TaskResult::Completed(Err(err)) = result {
276                    error!("Client stopped: {err}");
277                }
278                this.disconnect.store(true, Ordering::Relaxed);
279                this.subscriptions.clear().await;
280                this.message_dispatcher.clear().await;
281            }
282        };
283
284        // Spawn a new task
285        self.task_group.spawn(
286            {
287                let this = self.clone();
288                async move { this.background_loop(conn).await }
289            },
290            on_complete,
291        );
292    }
293
294    async fn background_loop(self: Arc<Self>, conn: Conn<serde_json::Value, Error>) -> Result<()> {
295        loop {
296            match select(self.send_chan.1.recv(), conn.recv()).await {
297                Either::Left(req) => {
298                    conn.send(req?).await?;
299                }
300                Either::Right(msg) => match self.handle_msg(msg?).await {
301                    Err(Error::SubscriptionBufferFull) => {
302                        return Err(Error::SubscriptionBufferFull);
303                    }
304                    Err(err) => {
305                        let endpoint = conn.peer_endpoint()?;
306                        error!("Handle a new msg from the endpoint {endpoint} : {err}",);
307                    }
308                    Ok(_) => {}
309                },
310            }
311        }
312    }
313
314    async fn handle_msg(&self, msg: serde_json::Value) -> Result<()> {
315        match serde_json::from_value::<NewMsg>(msg.clone()) {
316            Ok(msg) => match msg {
317                NewMsg::Response(res) => {
318                    debug!("<-- {res}");
319                    self.message_dispatcher.dispatch(res).await
320                }
321                NewMsg::Notification(nt) => {
322                    debug!("<-- {nt}");
323                    self.subscriptions.notify(nt).await
324                }
325            },
326            Err(err) => {
327                error!("Receive unexpected msg {msg}: {err}");
328                Err(Error::InvalidMsg("Unexpected msg".to_string()))
329            }
330        }
331    }
332}