1use std::sync::{atomic::Ordering, Arc};
6
7use async_channel::Receiver;
8use log::{debug, error, info};
9use serde::{de::DeserializeOwned, Deserialize, Serialize};
10use serde_json::json;
11
12use karyon_core::{
13 async_util::{select, timeout, Either, TaskResult},
14 util::random_32,
15};
16
17use karyon_net::{framed, Endpoint, FramedConn, MessageRx, MessageTx};
18
19#[cfg(feature = "ws")]
20use karyon_net::{layers::ws::WsConn, ClientLayer};
21
22use crate::{
23 client::{
24 message_dispatcher::MessageDispatcher, subscriptions::Subscriptions, Client, ClientBackend,
25 ClientConfig, RequestID, WsCodec,
26 },
27 codec::JsonRpcCodec,
28 error::{Error, Result},
29 message,
30};
31
32#[cfg(feature = "ws")]
33use crate::codec::JsonRpcWsCodec;
34
/// Capacity of the bounded outbound channel created in `build_backend_state`.
///
/// Requests queued by `send_request` wait in this channel until the I/O loop
/// writes them to the connection; once it holds this many pending values,
/// further sends wait for space.
const OUTBOUND_BUFFER_SIZE: usize = 10;
39
/// An inbound message read from a multiplexed connection.
///
/// With `#[serde(untagged)]`, serde tries each variant in declaration order and
/// keeps the first one that deserializes successfully, so a value is matched
/// against `Notification` before `Response`. Do not reorder the variants without
/// checking that the two message shapes cannot overlap.
#[derive(Serialize, Deserialize)]
#[serde(untagged)]
enum NewMsg {
    /// A server-initiated notification, routed to `Subscriptions::notify`.
    Notification(message::Notification),
    /// A reply to a request this client sent, routed to `MessageDispatcher::dispatch`.
    Response(message::Response),
}
46
47pub(super) async fn build_byte_backend<B>(
50 config: &ClientConfig,
51 codec: B,
52) -> Result<(ClientBackend, FramedConn<B>)>
53where
54 B: JsonRpcCodec,
55{
56 let conn = connect_byte(config, codec).await?;
57 let peer = conn.peer_endpoint();
58 info!("Successfully connected to the RPC server: {peer:?}");
59
60 let backend = build_backend_state(config);
61 Ok((backend, conn))
62}
63
#[cfg(feature = "ws")]
/// Connects to `config.endpoint` over WebSocket (plain or TLS) and prepares the
/// multiplexed client state that will drive it.
///
/// Returns the fresh [`ClientBackend::Multiplexed`] state together with the
/// WebSocket connection; the caller is responsible for starting the I/O loop.
///
/// # Errors
///
/// Propagates any error from `connect_ws` (unsupported protocol, missing TLS
/// config, connect/handshake failures).
pub(super) async fn build_ws_backend<W>(
    config: &ClientConfig,
    codec: W,
) -> Result<(ClientBackend, WsConn<W>)>
where
    W: JsonRpcWsCodec,
{
    let ws_conn = connect_ws(config, codec).await?;
    info!(
        "Successfully connected to the RPC server: {peer:?}",
        peer = ws_conn.peer_endpoint()
    );

    Ok((build_backend_state(config), ws_conn))
}
80
81fn build_backend_state(config: &ClientConfig) -> ClientBackend {
82 ClientBackend::Multiplexed {
83 message_dispatcher: MessageDispatcher::new(),
84 subscriptions: Subscriptions::new(config.subscription_buffer_size),
85 send_chan: async_channel::bounded(OUTBOUND_BUFFER_SIZE),
86 }
87}
88
89pub(super) fn start_io_loop<B, W, R, Wr>(client: &Arc<Client<B, W>>, reader: R, writer: Wr)
91where
92 B: JsonRpcCodec,
93 W: WsCodec,
94 R: MessageRx<Message = serde_json::Value> + Send + 'static,
95 Wr: MessageTx<Message = serde_json::Value> + Send + 'static,
96{
97 let this = client.clone();
98 client.task_group.spawn(
99 background_loop(client.clone(), reader, writer),
100 move |result: TaskResult<Result<()>>| async move {
101 if let TaskResult::Completed(Err(err)) = result {
102 error!("Client stopped: {err}");
103 }
104 this.disconnect.store(true, Ordering::Relaxed);
105 if let ClientBackend::Multiplexed {
106 subscriptions,
107 message_dispatcher,
108 ..
109 } = &this.backend
110 {
111 subscriptions.clear().await;
112 message_dispatcher.clear().await;
113 }
114 },
115 );
116}
117
118async fn background_loop<B, W, R, Wr>(
119 client: Arc<Client<B, W>>,
120 reader: R,
121 writer: Wr,
122) -> Result<()>
123where
124 B: JsonRpcCodec,
125 W: WsCodec,
126 R: MessageRx<Message = serde_json::Value> + Send,
127 Wr: MessageTx<Message = serde_json::Value> + Send,
128{
129 let ClientBackend::Multiplexed {
130 message_dispatcher,
131 subscriptions,
132 send_chan,
133 } = &client.backend
134 else {
135 return Err(Error::InvalidState("not in multiplexed mode".into()));
136 };
137
138 run_io_loop(
139 reader,
140 writer,
141 send_chan.1.clone(),
142 message_dispatcher,
143 subscriptions,
144 )
145 .await
146}
147
148async fn run_io_loop<R, Wr>(
149 mut reader: R,
150 mut writer: Wr,
151 outbound: Receiver<serde_json::Value>,
152 message_dispatcher: &MessageDispatcher,
153 subscriptions: &Subscriptions,
154) -> Result<()>
155where
156 R: MessageRx<Message = serde_json::Value> + Send,
157 Wr: MessageTx<Message = serde_json::Value> + Send,
158{
159 loop {
160 match select(outbound.recv(), reader.recv_msg()).await {
161 Either::Left(req) => {
162 writer.send_msg(req?).await?;
163 }
164 Either::Right(msg) => {
165 match handle_mux_msg(message_dispatcher, subscriptions, msg?).await {
166 Err(Error::SubscriptionBufferFull) => {
167 return Err(Error::SubscriptionBufferFull);
168 }
169 Err(err) => {
170 let ep = reader.peer_endpoint();
171 error!("Handle msg from {ep:?}: {err}");
172 }
173 Ok(_) => {}
174 }
175 }
176 }
177 }
178}
179
180async fn handle_mux_msg(
181 message_dispatcher: &MessageDispatcher,
182 subscriptions: &Subscriptions,
183 msg: serde_json::Value,
184) -> Result<()> {
185 match serde_json::from_value::<NewMsg>(msg.clone()) {
186 Ok(NewMsg::Response(res)) => {
187 debug!("<-- {res}");
188 message_dispatcher.dispatch(res).await
189 }
190 Ok(NewMsg::Notification(nt)) => {
191 debug!("<-- {nt}");
192 subscriptions.notify(nt).await
193 }
194 Err(err) => {
195 error!("Receive unexpected msg {msg}: {err}");
196 Err(Error::InvalidMsg("Unexpected msg".to_string()))
197 }
198 }
199}
200
201async fn connect_byte<B>(config: &ClientConfig, codec: B) -> Result<FramedConn<B>>
202where
203 B: JsonRpcCodec,
204{
205 let endpoint = config.endpoint.clone();
206
207 match &endpoint {
208 #[cfg(feature = "tcp")]
209 Endpoint::Tcp(..) => {
210 let stream = karyon_net::tcp::connect(&endpoint, config.tcp_config.clone()).await?;
211 Ok(framed(stream, codec))
212 }
213 #[cfg(feature = "tls")]
214 Endpoint::Tls(..) => {
215 let stream = karyon_net::tcp::connect(&endpoint, config.tcp_config.clone()).await?;
216 let tls_config = config.tls_config.as_ref().ok_or(Error::TLSConfigRequired)?;
217 let tls_layer = karyon_net::tls::TlsLayer::client(tls_config.clone());
218 let tls_stream = karyon_net::ClientLayer::handshake(&tls_layer, stream).await?;
219 Ok(framed(tls_stream, codec))
220 }
221 #[cfg(all(feature = "unix", target_family = "unix"))]
222 Endpoint::Unix(..) => {
223 let stream = karyon_net::unix::connect(&endpoint).await?;
224 Ok(framed(stream, codec))
225 }
226 _ => Err(Error::UnsupportedProtocol(endpoint.to_string())),
227 }
228}
229
#[cfg(feature = "ws")]
/// Opens a WebSocket connection (`ws://`, or `wss://` when the `tls` feature is
/// enabled) to `config.endpoint`, using `codec` for message framing.
///
/// The endpoint's string form is used as the WebSocket handshake URL.
///
/// # Errors
///
/// - [`Error::TLSConfigRequired`] for a `wss` endpoint when `config.tls_config`
///   is `None`. This is checked before any socket is opened.
/// - [`Error::UnsupportedProtocol`] for any other endpoint kind.
/// - Any connect or handshake error from `karyon_net`, converted via `?`.
async fn connect_ws<W>(config: &ClientConfig, codec: W) -> Result<WsConn<W>>
where
    W: JsonRpcWsCodec,
{
    let endpoint = &config.endpoint;
    let url = endpoint.to_string();

    match endpoint {
        Endpoint::Ws(..) => {
            let stream = karyon_net::tcp::connect(endpoint, config.tcp_config.clone()).await?;
            let layer = karyon_net::layers::ws::WsLayer::client(&url, codec);
            Ok(ClientLayer::handshake(&layer, stream).await?)
        }
        #[cfg(feature = "tls")]
        Endpoint::Wss(..) => {
            // Validate the TLS configuration before dialing so that a
            // misconfigured client does not open, then drop, a TCP connection.
            let tls_config = config.tls_config.as_ref().ok_or(Error::TLSConfigRequired)?;
            let stream = karyon_net::tcp::connect(endpoint, config.tcp_config.clone()).await?;
            let tls_layer = karyon_net::tls::TlsLayer::client(tls_config.clone());
            let tls_stream = karyon_net::ClientLayer::handshake(&tls_layer, stream).await?;
            let layer = karyon_net::layers::ws::WsLayer::client(&url, codec);
            Ok(ClientLayer::handshake(&layer, tls_stream).await?)
        }
        // `url` is exactly `endpoint.to_string()`; reuse it.
        _ => Err(Error::UnsupportedProtocol(url)),
    }
}
256
257pub(super) async fn send_request<B, W, T>(
258 client: &Client<B, W>,
259 method: &str,
260 params: T,
261) -> Result<message::Response>
262where
263 B: JsonRpcCodec,
264 W: WsCodec,
265 T: Serialize + DeserializeOwned,
266{
267 let ClientBackend::Multiplexed {
268 message_dispatcher,
269 send_chan,
270 ..
271 } = &client.backend
272 else {
273 return Err(Error::InvalidState("not in multiplexed mode".into()));
274 };
275
276 let id: RequestID = random_32();
277 let request = message::Request {
278 jsonrpc: message::JSONRPC_VERSION.to_string(),
279 id: json!(id),
280 method: method.to_string(),
281 params: Some(json!(params)),
282 };
283
284 if client.disconnect.load(Ordering::Relaxed) {
285 return Err(Error::ClientDisconnected);
286 }
287 let req = serde_json::to_value(request)?;
288 send_chan.0.send(req).await?;
289
290 let rx = message_dispatcher.register(id).await;
291
292 let result = match client.config.timeout {
293 Some(t) => timeout(std::time::Duration::from_millis(t), rx.recv()).await?,
294 None => rx.recv().await,
295 };
296
297 let response = match result {
298 Ok(r) => r,
299 Err(err) => {
300 message_dispatcher.unregister(&id).await;
301 return Err(err.into());
302 }
303 };
304
305 if let Some(error) = response.error {
306 return Err(Error::SubscribeError(error.code, error.message));
307 }
308
309 let resp_id = response
310 .id
311 .as_ref()
312 .ok_or_else(|| Error::InvalidMsg("Missing response id".to_string()))?;
313 if *resp_id != id {
314 return Err(Error::InvalidMsg("Invalid response id".to_string()));
315 }
316
317 Ok(response)
318}