// karyon_jsonrpc/server/mod.rs

pub mod builder;
pub mod channel;
pub mod pubsub_service;
mod response_queue;
pub mod service;

use std::{collections::HashMap, sync::Arc};

use log::{debug, error, info, trace, warn};

use karyon_core::{
    async_runtime::Executor,
    async_util::{select, Either, TaskGroup, TaskResult},
};

#[cfg(feature = "tls")]
use karyon_net::async_rustls::rustls;
#[cfg(feature = "ws")]
use karyon_net::ws::ServerWsConfig;
use karyon_net::{Conn, Listener};

#[cfg(feature = "tcp")]
use crate::net::TcpConfig;

use crate::{
    codec::ClonableJsonCodec,
    error::{Error, Result},
    message,
    net::Endpoint,
    server::channel::NewNotification,
};

pub use builder::ServerBuilder;
pub use channel::Channel;
pub use pubsub_service::{PubSubRPCMethod, PubSubRPCService};
pub use service::{RPCMethod, RPCService};

use response_queue::ResponseQueue;

/// Error message sent back when a request's method is not of the form
/// `service.method`.
pub const INVALID_REQUEST_ERROR_MSG: &str = "Invalid request";

/// Error message sent back when an incoming JSON value cannot be
/// deserialized into a JSON-RPC request.
pub const FAILED_TO_PARSE_ERROR_MSG: &str = "Failed to parse";

/// Error message sent back when no regular or pubsub service provides the
/// requested method.
pub const METHOD_NOT_FOUND_ERROR_MSG: &str = "Method not found";

/// Error message sent back when a request's `jsonrpc` field differs from the
/// version this server supports.
pub const UNSUPPORTED_JSONRPC_VERSION: &str = "Unsupported jsonrpc version";

/// Capacity of the bounded per-connection channel that carries subscription
/// notifications to the connection's writer task.
const CHANNEL_SUBSCRIPTION_BUFFER_SIZE: usize = 100;

47struct NewRequest {
48    srvc_name: String,
49    method_name: String,
50    msg: message::Request,
51}
52
53enum SanityCheckResult {
54    NewReq(NewRequest),
55    ErrRes(message::Response),
56}
57
58fn default_notification_encoder(nt: NewNotification) -> message::Notification {
59    let params = Some(serde_json::json!(message::NotificationResult {
60        subscription: nt.sub_id,
61        result: Some(nt.result),
62    }));
63
64    message::Notification {
65        jsonrpc: message::JSONRPC_VERSION.to_string(),
66        method: nt.method,
67        params,
68    }
69}
70
71struct ServerConfig {
72    endpoint: Endpoint,
73    #[cfg(feature = "tcp")]
74    tcp_config: TcpConfig,
75    #[cfg(feature = "tls")]
76    tls_config: Option<rustls::ServerConfig>,
77    services: HashMap<String, Arc<dyn RPCService + 'static>>,
78    pubsub_services: HashMap<String, Arc<dyn PubSubRPCService + 'static>>,
79    notification_encoder: fn(NewNotification) -> message::Notification,
80}
81
82/// Represents an RPC server
83pub struct Server {
84    listener: Listener<serde_json::Value, Error>,
85    task_group: TaskGroup,
86    config: ServerConfig,
87}
88
89impl Server {
90    /// Starts the RPC server by spawning a new task for the main accept loop.
91    /// The accept loop listens for incoming connections.
92    ///
93    /// This function does not block the current thread. If you need the thread to block,
94    /// use the [`start_block`] method instead.
95    pub fn start(self: Arc<Self>) {
96        // Spawns a new task for the main accept loop
97        self.task_group
98            .spawn(self.clone().start_block(), |_| async {});
99    }
100
101    /// Starts the RPC server by running the main accept loop.
102    /// The accept loop listens for incoming connections and blocks the current thread.
103    ///
104    /// If you prefer a non-blocking implementation, use the [`start`] method instead.
105    pub async fn start_block(self: Arc<Self>) -> Result<()> {
106        if let Err(err) = self.accept_loop().await {
107            error!("Main accept loop stopped: {err}");
108            self.shutdown().await;
109        };
110        Ok(())
111    }
112
113    async fn accept_loop(self: &Arc<Self>) -> Result<()> {
114        loop {
115            match self.listener.accept().await {
116                Ok(conn) => {
117                    if let Err(err) = self.handle_conn(conn).await {
118                        error!("Handle a new connection: {err}")
119                    }
120                }
121                Err(err) => {
122                    error!("Accept a new connection: {err}")
123                }
124            }
125        }
126    }
127
128    /// Returns the local endpoint.
129    pub fn local_endpoint(&self) -> Result<Endpoint> {
130        self.listener.local_endpoint()
131    }
132
133    /// Shuts down the RPC server
134    pub async fn shutdown(&self) {
135        self.task_group.cancel().await;
136    }
137
138    /// Handles a new connection
139    async fn handle_conn(self: &Arc<Self>, conn: Conn<serde_json::Value, Error>) -> Result<()> {
140        let endpoint: Option<Endpoint> = conn.peer_endpoint().ok();
141        debug!("Handle a new connection {endpoint:?}");
142
143        let conn = Arc::new(conn);
144
145        let (ch_tx, ch_rx) = async_channel::bounded(CHANNEL_SUBSCRIPTION_BUFFER_SIZE);
146        // Create a new connection channel for managing subscriptions
147        let channel = Channel::new(ch_tx);
148
149        // Create a response queue
150        let queue = ResponseQueue::new();
151
152        let chan = channel.clone();
153        let on_complete = |result: TaskResult<Result<()>>| async move {
154            if let TaskResult::Completed(Err(err)) = result {
155                debug!("Notification loop stopped: {err}");
156            }
157            // Close the connection channel
158            chan.close();
159        };
160
161        let notification_encoder = self.config.notification_encoder;
162
163        // Start listening for new responses in the queue or new notifications
164        self.task_group.spawn(
165            {
166                let conn = conn.clone();
167                let queue = queue.clone();
168                async move {
169                    loop {
170                        // The select function will prioritize the first future if both futures are ready.
171                        // This gives priority to the responses in the response queue.
172                        match select(queue.recv(), ch_rx.recv()).await {
173                            Either::Left(res) => {
174                                conn.send(res).await?;
175                            }
176                            Either::Right(notification) => {
177                                let nt = notification?;
178                                let notification = (notification_encoder)(nt);
179                                debug!("--> {notification}");
180                                conn.send(serde_json::json!(notification)).await?;
181                            }
182                        }
183                    }
184                }
185            },
186            on_complete,
187        );
188
189        let chan = channel.clone();
190        let on_complete = |result: TaskResult<Result<()>>| async move {
191            if let TaskResult::Completed(Err(err)) = result {
192                error!("Connection {endpoint:?} dropped: {err}");
193            } else {
194                warn!("Connection {endpoint:?} dropped");
195            }
196            // Close the connection channel when the connection dropped
197            chan.close();
198        };
199
200        // Spawn a new task and wait for new requests.
201        self.task_group.spawn(
202            {
203                let this = self.clone();
204                async move {
205                    loop {
206                        let msg = conn.recv().await?;
207                        this.new_request(queue.clone(), channel.clone(), msg).await;
208                    }
209                }
210            },
211            on_complete,
212        );
213
214        Ok(())
215    }
216
217    fn sanity_check(&self, request: serde_json::Value) -> SanityCheckResult {
218        let rpc_msg = match serde_json::from_value::<message::Request>(request) {
219            Ok(m) => m,
220            Err(_) => {
221                let response = message::Response {
222                    error: Some(message::Error {
223                        code: message::PARSE_ERROR_CODE,
224                        message: FAILED_TO_PARSE_ERROR_MSG.to_string(),
225                        data: None,
226                    }),
227                    ..Default::default()
228                };
229                return SanityCheckResult::ErrRes(response);
230            }
231        };
232
233        if rpc_msg.jsonrpc != message::JSONRPC_VERSION {
234            let response = message::Response {
235                error: Some(message::Error {
236                    code: message::INVALID_REQUEST_ERROR_CODE,
237                    message: UNSUPPORTED_JSONRPC_VERSION.to_string(),
238                    data: None,
239                }),
240                id: Some(rpc_msg.id),
241                ..Default::default()
242            };
243            return SanityCheckResult::ErrRes(response);
244        }
245
246        debug!("<-- {rpc_msg}");
247
248        // Parse the service name and its method
249        let srvc_method_str = rpc_msg.method.clone();
250        let srvc_method: Vec<&str> = srvc_method_str.split('.').collect();
251        if srvc_method.len() < 2 {
252            let response = message::Response {
253                error: Some(message::Error {
254                    code: message::INVALID_REQUEST_ERROR_CODE,
255                    message: INVALID_REQUEST_ERROR_MSG.to_string(),
256                    data: None,
257                }),
258                id: Some(rpc_msg.id),
259                ..Default::default()
260            };
261            return SanityCheckResult::ErrRes(response);
262        }
263
264        let srvc_name = srvc_method[0].to_string();
265        // Method name is allowed to contain dots
266        let method_name = srvc_method[1..].join(".");
267
268        SanityCheckResult::NewReq(NewRequest {
269            srvc_name,
270            method_name,
271            msg: rpc_msg,
272        })
273    }
274
275    /// Spawns a new task for handling the new request
276    async fn new_request(
277        self: &Arc<Self>,
278        queue: Arc<ResponseQueue<serde_json::Value>>,
279        channel: Arc<Channel>,
280        msg: serde_json::Value,
281    ) {
282        trace!("--> new request {msg}");
283        let on_complete = |result: TaskResult<Result<()>>| async move {
284            if let TaskResult::Completed(Err(err)) = result {
285                error!("Handle a new request: {err}");
286            }
287        };
288
289        // Spawns a new task for handling the new request, and push the
290        // response to the response queue.
291        self.task_group.spawn(
292            {
293                let this = self.clone();
294                async move {
295                    let response = this.handle_request(channel, msg).await;
296                    debug!("--> {response}");
297                    queue.push(serde_json::json!(response)).await;
298                    Ok(())
299                }
300            },
301            on_complete,
302        );
303    }
304
305    /// Handles the new request, and returns an RPC Response that has either
306    /// an error or result
307    async fn handle_request(
308        &self,
309        channel: Arc<Channel>,
310        msg: serde_json::Value,
311    ) -> message::Response {
312        let req = match self.sanity_check(msg) {
313            SanityCheckResult::NewReq(req) => req,
314            SanityCheckResult::ErrRes(res) => return res,
315        };
316
317        let mut response = message::Response {
318            error: None,
319            result: None,
320            id: Some(req.msg.id.clone()),
321            ..Default::default()
322        };
323
324        // Check if the service exists in pubsub services list
325        if let Some(service) = self.config.pubsub_services.get(&req.srvc_name) {
326            // Check if the method exists within the service
327            if let Some(method) = service.get_pubsub_method(&req.method_name) {
328                let params = req.msg.params.unwrap_or(serde_json::json!(()));
329                let result = method(channel, req.msg.method, params);
330                response.result = match result.await {
331                    Ok(res) => Some(res),
332                    Err(err) => return err.to_response(Some(req.msg.id), None),
333                };
334
335                return response;
336            }
337        }
338
339        // Check if the service exists in services list
340        if let Some(service) = self.config.services.get(&req.srvc_name) {
341            // Check if the method exists within the service
342            if let Some(method) = service.get_method(&req.method_name) {
343                let params = req.msg.params.unwrap_or(serde_json::json!(()));
344                let result = method(params);
345                response.result = match result.await {
346                    Ok(res) => Some(res),
347                    Err(err) => return err.to_response(Some(req.msg.id), None),
348                };
349
350                return response;
351            }
352        }
353
354        response.error = Some(message::Error {
355            code: message::METHOD_NOT_FOUND_ERROR_CODE,
356            message: METHOD_NOT_FOUND_ERROR_MSG.to_string(),
357            data: None,
358        });
359
360        response
361    }
362
363    /// Initializes a new [`Server`] from the provided [`ServerConfig`]
364    async fn init(
365        config: ServerConfig,
366        ex: Option<Executor>,
367        codec: impl ClonableJsonCodec + 'static,
368    ) -> Result<Arc<Self>> {
369        let task_group = match ex {
370            Some(ex) => TaskGroup::with_executor(ex),
371            None => TaskGroup::new(),
372        };
373        let listener = Self::listen(&config, codec).await?;
374        info!("RPC server listens to the endpoint: {}", config.endpoint);
375
376        let server = Arc::new(Server {
377            listener,
378            task_group,
379            config,
380        });
381
382        Ok(server)
383    }
384
385    async fn listen(
386        config: &ServerConfig,
387        codec: impl ClonableJsonCodec + 'static,
388    ) -> Result<Listener<serde_json::Value, Error>> {
389        let endpoint = config.endpoint.clone();
390        let listener: Listener<serde_json::Value, Error> = match endpoint {
391            #[cfg(feature = "tcp")]
392            Endpoint::Tcp(..) => Box::new(
393                karyon_net::tcp::listen(&endpoint, config.tcp_config.clone(), codec).await?,
394            ),
395            #[cfg(feature = "tls")]
396            Endpoint::Tls(..) => match &config.tls_config {
397                Some(conf) => Box::new(
398                    karyon_net::tls::listen(
399                        &endpoint,
400                        karyon_net::tls::ServerTlsConfig {
401                            server_config: conf.clone(),
402                            tcp_config: config.tcp_config.clone(),
403                        },
404                        codec,
405                    )
406                    .await?,
407                ),
408                None => return Err(Error::TLSConfigRequired),
409            },
410            #[cfg(feature = "ws")]
411            Endpoint::Ws(..) => {
412                let config = ServerWsConfig {
413                    tcp_config: config.tcp_config.clone(),
414                    wss_config: None,
415                };
416                Box::new(karyon_net::ws::listen(&endpoint, config, codec).await?)
417            }
418            #[cfg(all(feature = "ws", feature = "tls"))]
419            Endpoint::Wss(..) => match &config.tls_config {
420                Some(conf) => Box::new(
421                    karyon_net::ws::listen(
422                        &endpoint,
423                        ServerWsConfig {
424                            tcp_config: config.tcp_config.clone(),
425                            wss_config: Some(karyon_net::ws::ServerWssConfig {
426                                server_config: conf.clone(),
427                            }),
428                        },
429                        codec,
430                    )
431                    .await?,
432                ),
433                None => return Err(Error::TLSConfigRequired),
434            },
435            #[cfg(all(feature = "unix", target_family = "unix"))]
436            Endpoint::Unix(..) => Box::new(karyon_net::unix::listen(
437                &endpoint,
438                Default::default(),
439                codec,
440            )?),
441
442            _ => return Err(Error::UnsupportedProtocol(endpoint.to_string())),
443        };
444
445        Ok(listener)
446    }
447}