Skip to main content

karyon_jsonrpc/client/
multiplexed.rs

1//! Multiplexed client mode (TCP/TLS/Unix and WS/WSS).
2//! A single message connection is shared across all calls and
3//! subscriptions. A background task splits reader from writer.
4
5use std::sync::{atomic::Ordering, Arc};
6
7use async_channel::Receiver;
8use log::{debug, error, info};
9use serde::{de::DeserializeOwned, Deserialize, Serialize};
10use serde_json::json;
11
12use karyon_core::{
13    async_util::{select, timeout, Either, TaskResult},
14    util::random_32,
15};
16
17use karyon_net::{framed, Endpoint, FramedConn, MessageRx, MessageTx};
18
19#[cfg(feature = "ws")]
20use karyon_net::{layers::ws::WsConn, ClientLayer};
21
22use crate::{
23    client::{
24        message_dispatcher::MessageDispatcher, subscriptions::Subscriptions, Client, ClientBackend,
25        ClientConfig, RequestID, WsCodec,
26    },
27    codec::JsonRpcCodec,
28    error::{Error, Result},
29    message,
30};
31
32#[cfg(feature = "ws")]
33use crate::codec::JsonRpcWsCodec;
34
35/// Capacity of the outbound queue feeding the writer task. Small
36/// because each in-flight call already has its own response channel,
37/// so this only buffers bursts between caller and writer.
38const OUTBOUND_BUFFER_SIZE: usize = 10;
39
40#[derive(Serialize, Deserialize)]
41#[serde(untagged)]
42enum NewMsg {
43    Notification(message::Notification),
44    Response(message::Response),
45}
46
47/// Build the multiplexed backend by connecting over a byte stream
48/// (TCP, TLS, or Unix) and framing it.
49pub(super) async fn build_byte_backend<B>(
50    config: &ClientConfig,
51    codec: B,
52) -> Result<(ClientBackend, FramedConn<B>)>
53where
54    B: JsonRpcCodec,
55{
56    let conn = connect_byte(config, codec).await?;
57    let peer = conn.peer_endpoint();
58    info!("Successfully connected to the RPC server: {peer:?}");
59
60    let backend = build_backend_state(config);
61    Ok((backend, conn))
62}
63
/// Open a WebSocket (`ws://`, or `wss://` with the `tls` feature) connection
/// to `config.endpoint` and pair it with freshly initialised multiplexed
/// state.
///
/// # Errors
///
/// Propagates any failure from connecting or from the TLS/WebSocket
/// handshakes.
#[cfg(feature = "ws")]
pub(super) async fn build_ws_backend<W>(
    config: &ClientConfig,
    codec: W,
) -> Result<(ClientBackend, WsConn<W>)>
where
    W: JsonRpcWsCodec,
{
    let connection = connect_ws(config, codec).await?;
    info!(
        "Successfully connected to the RPC server: {peer:?}",
        peer = connection.peer_endpoint()
    );
    Ok((build_backend_state(config), connection))
}
80
81fn build_backend_state(config: &ClientConfig) -> ClientBackend {
82    ClientBackend::Multiplexed {
83        message_dispatcher: MessageDispatcher::new(),
84        subscriptions: Subscriptions::new(config.subscription_buffer_size),
85        send_chan: async_channel::bounded(OUTBOUND_BUFFER_SIZE),
86    }
87}
88
89/// Spawn the reader/writer task that drives the multiplexed wire.
90pub(super) fn start_io_loop<B, W, R, Wr>(client: &Arc<Client<B, W>>, reader: R, writer: Wr)
91where
92    B: JsonRpcCodec,
93    W: WsCodec,
94    R: MessageRx<Message = serde_json::Value> + Send + 'static,
95    Wr: MessageTx<Message = serde_json::Value> + Send + 'static,
96{
97    let this = client.clone();
98    client.task_group.spawn(
99        background_loop(client.clone(), reader, writer),
100        move |result: TaskResult<Result<()>>| async move {
101            if let TaskResult::Completed(Err(err)) = result {
102                error!("Client stopped: {err}");
103            }
104            this.disconnect.store(true, Ordering::Relaxed);
105            if let ClientBackend::Multiplexed {
106                subscriptions,
107                message_dispatcher,
108                ..
109            } = &this.backend
110            {
111                subscriptions.clear().await;
112                message_dispatcher.clear().await;
113            }
114        },
115    );
116}
117
118async fn background_loop<B, W, R, Wr>(
119    client: Arc<Client<B, W>>,
120    reader: R,
121    writer: Wr,
122) -> Result<()>
123where
124    B: JsonRpcCodec,
125    W: WsCodec,
126    R: MessageRx<Message = serde_json::Value> + Send,
127    Wr: MessageTx<Message = serde_json::Value> + Send,
128{
129    let ClientBackend::Multiplexed {
130        message_dispatcher,
131        subscriptions,
132        send_chan,
133    } = &client.backend
134    else {
135        return Err(Error::InvalidState("not in multiplexed mode".into()));
136    };
137
138    run_io_loop(
139        reader,
140        writer,
141        send_chan.1.clone(),
142        message_dispatcher,
143        subscriptions,
144    )
145    .await
146}
147
148async fn run_io_loop<R, Wr>(
149    mut reader: R,
150    mut writer: Wr,
151    outbound: Receiver<serde_json::Value>,
152    message_dispatcher: &MessageDispatcher,
153    subscriptions: &Subscriptions,
154) -> Result<()>
155where
156    R: MessageRx<Message = serde_json::Value> + Send,
157    Wr: MessageTx<Message = serde_json::Value> + Send,
158{
159    loop {
160        match select(outbound.recv(), reader.recv_msg()).await {
161            Either::Left(req) => {
162                writer.send_msg(req?).await?;
163            }
164            Either::Right(msg) => {
165                match handle_mux_msg(message_dispatcher, subscriptions, msg?).await {
166                    Err(Error::SubscriptionBufferFull) => {
167                        return Err(Error::SubscriptionBufferFull);
168                    }
169                    Err(err) => {
170                        let ep = reader.peer_endpoint();
171                        error!("Handle msg from {ep:?}: {err}");
172                    }
173                    Ok(_) => {}
174                }
175            }
176        }
177    }
178}
179
180async fn handle_mux_msg(
181    message_dispatcher: &MessageDispatcher,
182    subscriptions: &Subscriptions,
183    msg: serde_json::Value,
184) -> Result<()> {
185    match serde_json::from_value::<NewMsg>(msg.clone()) {
186        Ok(NewMsg::Response(res)) => {
187            debug!("<-- {res}");
188            message_dispatcher.dispatch(res).await
189        }
190        Ok(NewMsg::Notification(nt)) => {
191            debug!("<-- {nt}");
192            subscriptions.notify(nt).await
193        }
194        Err(err) => {
195            error!("Receive unexpected msg {msg}: {err}");
196            Err(Error::InvalidMsg("Unexpected msg".to_string()))
197        }
198    }
199}
200
201async fn connect_byte<B>(config: &ClientConfig, codec: B) -> Result<FramedConn<B>>
202where
203    B: JsonRpcCodec,
204{
205    let endpoint = config.endpoint.clone();
206
207    match &endpoint {
208        #[cfg(feature = "tcp")]
209        Endpoint::Tcp(..) => {
210            let stream = karyon_net::tcp::connect(&endpoint, config.tcp_config.clone()).await?;
211            Ok(framed(stream, codec))
212        }
213        #[cfg(feature = "tls")]
214        Endpoint::Tls(..) => {
215            let stream = karyon_net::tcp::connect(&endpoint, config.tcp_config.clone()).await?;
216            let tls_config = config.tls_config.as_ref().ok_or(Error::TLSConfigRequired)?;
217            let tls_layer = karyon_net::tls::TlsLayer::client(tls_config.clone());
218            let tls_stream = karyon_net::ClientLayer::handshake(&tls_layer, stream).await?;
219            Ok(framed(tls_stream, codec))
220        }
221        #[cfg(all(feature = "unix", target_family = "unix"))]
222        Endpoint::Unix(..) => {
223            let stream = karyon_net::unix::connect(&endpoint).await?;
224            Ok(framed(stream, codec))
225        }
226        _ => Err(Error::UnsupportedProtocol(endpoint.to_string())),
227    }
228}
229
/// Dial `config.endpoint` and perform the WebSocket handshake with `codec`.
///
/// `ws://` endpoints run over plain TCP. `wss://` endpoints (with the `tls`
/// feature) first complete a TLS handshake using `config.tls_config`.
///
/// # Errors
///
/// - [`Error::TLSConfigRequired`] for `wss://` without a TLS config.
/// - [`Error::UnsupportedProtocol`] for non-WebSocket endpoints.
/// - Connection and handshake errors are propagated.
#[cfg(feature = "ws")]
async fn connect_ws<W>(config: &ClientConfig, codec: W) -> Result<WsConn<W>>
where
    W: JsonRpcWsCodec,
{
    let endpoint = &config.endpoint;
    let url = endpoint.to_string();

    match endpoint {
        Endpoint::Ws(..) => {
            let tcp = karyon_net::tcp::connect(endpoint, config.tcp_config.clone()).await?;
            let ws_layer = karyon_net::layers::ws::WsLayer::client(&url, codec);
            Ok(ClientLayer::handshake(&ws_layer, tcp).await?)
        }
        #[cfg(feature = "tls")]
        Endpoint::Wss(..) => {
            let tcp = karyon_net::tcp::connect(endpoint, config.tcp_config.clone()).await?;
            let client_tls = config.tls_config.as_ref().ok_or(Error::TLSConfigRequired)?;
            let tls_layer = karyon_net::tls::TlsLayer::client(client_tls.clone());
            let secured = karyon_net::ClientLayer::handshake(&tls_layer, tcp).await?;
            let ws_layer = karyon_net::layers::ws::WsLayer::client(&url, codec);
            Ok(ClientLayer::handshake(&ws_layer, secured).await?)
        }
        _ => Err(Error::UnsupportedProtocol(url)),
    }
}
256
257pub(super) async fn send_request<B, W, T>(
258    client: &Client<B, W>,
259    method: &str,
260    params: T,
261) -> Result<message::Response>
262where
263    B: JsonRpcCodec,
264    W: WsCodec,
265    T: Serialize + DeserializeOwned,
266{
267    let ClientBackend::Multiplexed {
268        message_dispatcher,
269        send_chan,
270        ..
271    } = &client.backend
272    else {
273        return Err(Error::InvalidState("not in multiplexed mode".into()));
274    };
275
276    let id: RequestID = random_32();
277    let request = message::Request {
278        jsonrpc: message::JSONRPC_VERSION.to_string(),
279        id: json!(id),
280        method: method.to_string(),
281        params: Some(json!(params)),
282    };
283
284    if client.disconnect.load(Ordering::Relaxed) {
285        return Err(Error::ClientDisconnected);
286    }
287    let req = serde_json::to_value(request)?;
288    send_chan.0.send(req).await?;
289
290    let rx = message_dispatcher.register(id).await;
291
292    let result = match client.config.timeout {
293        Some(t) => timeout(std::time::Duration::from_millis(t), rx.recv()).await?,
294        None => rx.recv().await,
295    };
296
297    let response = match result {
298        Ok(r) => r,
299        Err(err) => {
300            message_dispatcher.unregister(&id).await;
301            return Err(err.into());
302        }
303    };
304
305    if let Some(error) = response.error {
306        return Err(Error::SubscribeError(error.code, error.message));
307    }
308
309    let resp_id = response
310        .id
311        .as_ref()
312        .ok_or_else(|| Error::InvalidMsg("Missing response id".to_string()))?;
313    if *resp_id != id {
314        return Err(Error::InvalidMsg("Invalid response id".to_string()));
315    }
316
317    Ok(response)
318}