1pub mod builder;
2pub mod channel;
3pub mod pubsub_service;
4mod response_queue;
5pub mod service;
6
7use std::{collections::HashMap, sync::Arc};
8
9use log::{debug, error, info, trace, warn};
10
11use karyon_core::{
12 async_runtime::Executor,
13 async_util::{select, Either, TaskGroup, TaskResult},
14};
15
16#[cfg(feature = "tls")]
17use karyon_net::async_rustls::rustls;
18#[cfg(feature = "ws")]
19use karyon_net::ws::ServerWsConfig;
20use karyon_net::{Conn, Listener};
21
22#[cfg(feature = "tcp")]
23use crate::net::TcpConfig;
24
25use crate::{
26 codec::ClonableJsonCodec,
27 error::{Error, Result},
28 message,
29 net::Endpoint,
30 server::channel::NewNotification,
31};
32
33pub use builder::ServerBuilder;
34pub use channel::Channel;
35pub use pubsub_service::{PubSubRPCMethod, PubSubRPCService};
36pub use service::{RPCMethod, RPCService};
37
38use response_queue::ResponseQueue;
39
/// Error message the server returns for a structurally invalid request
/// (for example, a method name that is not of the form `service.method`).
pub const INVALID_REQUEST_ERROR_MSG: &'static str = "Invalid request";
/// Error message the server returns when an incoming JSON value cannot be
/// deserialized into a request.
pub const FAILED_TO_PARSE_ERROR_MSG: &'static str = "Failed to parse";
/// Error message the server returns when no registered service exposes the
/// requested method.
pub const METHOD_NOT_FOUND_ERROR_MSG: &'static str = "Method not found";
/// Error message the server returns when the request's `jsonrpc` field does
/// not match the supported protocol version.
pub const UNSUPPORTED_JSONRPC_VERSION: &'static str = "Unsupported jsonrpc version";

/// Capacity of the bounded, per-connection channel that carries subscription
/// notifications to the connection's outbound loop.
const CHANNEL_SUBSCRIPTION_BUFFER_SIZE: usize = 100;
46
/// A request that passed `Server::sanity_check` and is ready to be dispatched.
struct NewRequest {
    /// Part of the request's `method` before the first `.`.
    srvc_name: String,
    /// Part of the request's `method` after the first `.` (may itself contain dots).
    method_name: String,
    /// The full deserialized request.
    msg: message::Request,
}
52
/// Outcome of validating an incoming JSON value.
enum SanityCheckResult {
    /// The value is a well-formed request addressed to `service.method`.
    NewReq(NewRequest),
    /// Validation failed; this error response should be sent back as-is.
    ErrRes(message::Response),
}
57
58fn default_notification_encoder(nt: NewNotification) -> message::Notification {
59 let params = Some(serde_json::json!(message::NotificationResult {
60 subscription: nt.sub_id,
61 result: Some(nt.result),
62 }));
63
64 message::Notification {
65 jsonrpc: message::JSONRPC_VERSION.to_string(),
66 method: nt.method,
67 params,
68 }
69}
70
/// Settings the server is built from.
struct ServerConfig {
    /// Endpoint to listen on; its variant selects the transport in `Server::listen`.
    endpoint: Endpoint,
    /// TCP settings, cloned into the TCP, TLS and WebSocket listeners.
    #[cfg(feature = "tcp")]
    tcp_config: TcpConfig,
    /// TLS settings; required for `Tls`/`Wss` endpoints, otherwise
    /// `Server::listen` fails with `Error::TLSConfigRequired`.
    #[cfg(feature = "tls")]
    tls_config: Option<rustls::ServerConfig>,
    /// Regular RPC services, keyed by service name.
    services: HashMap<String, Arc<dyn RPCService + 'static>>,
    /// Pub/sub RPC services, keyed by service name; consulted before `services`.
    pubsub_services: HashMap<String, Arc<dyn PubSubRPCService + 'static>>,
    /// Turns each subscription notification into the message sent to the client.
    notification_encoder: fn(NewNotification) -> message::Notification,
}
81
/// A JSON-RPC server that accepts connections and dispatches requests to
/// registered services.
pub struct Server {
    /// Listener that yields incoming JSON connections.
    listener: Listener<serde_json::Value, Error>,
    /// Owns every task the server spawns (accept loop, per-connection loops,
    /// per-request handlers); cancelled by `Server::shutdown`.
    task_group: TaskGroup,
    /// Server settings and registered services.
    config: ServerConfig,
}
88
89impl Server {
90 pub fn start(self: Arc<Self>) {
96 self.task_group
98 .spawn(self.clone().start_block(), |_| async {});
99 }
100
101 pub async fn start_block(self: Arc<Self>) -> Result<()> {
106 if let Err(err) = self.accept_loop().await {
107 error!("Main accept loop stopped: {err}");
108 self.shutdown().await;
109 };
110 Ok(())
111 }
112
113 async fn accept_loop(self: &Arc<Self>) -> Result<()> {
114 loop {
115 match self.listener.accept().await {
116 Ok(conn) => {
117 if let Err(err) = self.handle_conn(conn).await {
118 error!("Handle a new connection: {err}")
119 }
120 }
121 Err(err) => {
122 error!("Accept a new connection: {err}")
123 }
124 }
125 }
126 }
127
128 pub fn local_endpoint(&self) -> Result<Endpoint> {
130 self.listener.local_endpoint()
131 }
132
133 pub async fn shutdown(&self) {
135 self.task_group.cancel().await;
136 }
137
138 async fn handle_conn(self: &Arc<Self>, conn: Conn<serde_json::Value, Error>) -> Result<()> {
140 let endpoint: Option<Endpoint> = conn.peer_endpoint().ok();
141 debug!("Handle a new connection {endpoint:?}");
142
143 let conn = Arc::new(conn);
144
145 let (ch_tx, ch_rx) = async_channel::bounded(CHANNEL_SUBSCRIPTION_BUFFER_SIZE);
146 let channel = Channel::new(ch_tx);
148
149 let queue = ResponseQueue::new();
151
152 let chan = channel.clone();
153 let on_complete = |result: TaskResult<Result<()>>| async move {
154 if let TaskResult::Completed(Err(err)) = result {
155 debug!("Notification loop stopped: {err}");
156 }
157 chan.close();
159 };
160
161 let notification_encoder = self.config.notification_encoder;
162
163 self.task_group.spawn(
165 {
166 let conn = conn.clone();
167 let queue = queue.clone();
168 async move {
169 loop {
170 match select(queue.recv(), ch_rx.recv()).await {
173 Either::Left(res) => {
174 conn.send(res).await?;
175 }
176 Either::Right(notification) => {
177 let nt = notification?;
178 let notification = (notification_encoder)(nt);
179 debug!("--> {notification}");
180 conn.send(serde_json::json!(notification)).await?;
181 }
182 }
183 }
184 }
185 },
186 on_complete,
187 );
188
189 let chan = channel.clone();
190 let on_complete = |result: TaskResult<Result<()>>| async move {
191 if let TaskResult::Completed(Err(err)) = result {
192 error!("Connection {endpoint:?} dropped: {err}");
193 } else {
194 warn!("Connection {endpoint:?} dropped");
195 }
196 chan.close();
198 };
199
200 self.task_group.spawn(
202 {
203 let this = self.clone();
204 async move {
205 loop {
206 let msg = conn.recv().await?;
207 this.new_request(queue.clone(), channel.clone(), msg).await;
208 }
209 }
210 },
211 on_complete,
212 );
213
214 Ok(())
215 }
216
217 fn sanity_check(&self, request: serde_json::Value) -> SanityCheckResult {
218 let rpc_msg = match serde_json::from_value::<message::Request>(request) {
219 Ok(m) => m,
220 Err(_) => {
221 let response = message::Response {
222 error: Some(message::Error {
223 code: message::PARSE_ERROR_CODE,
224 message: FAILED_TO_PARSE_ERROR_MSG.to_string(),
225 data: None,
226 }),
227 ..Default::default()
228 };
229 return SanityCheckResult::ErrRes(response);
230 }
231 };
232
233 if rpc_msg.jsonrpc != message::JSONRPC_VERSION {
234 let response = message::Response {
235 error: Some(message::Error {
236 code: message::INVALID_REQUEST_ERROR_CODE,
237 message: UNSUPPORTED_JSONRPC_VERSION.to_string(),
238 data: None,
239 }),
240 id: Some(rpc_msg.id),
241 ..Default::default()
242 };
243 return SanityCheckResult::ErrRes(response);
244 }
245
246 debug!("<-- {rpc_msg}");
247
248 let srvc_method_str = rpc_msg.method.clone();
250 let srvc_method: Vec<&str> = srvc_method_str.split('.').collect();
251 if srvc_method.len() < 2 {
252 let response = message::Response {
253 error: Some(message::Error {
254 code: message::INVALID_REQUEST_ERROR_CODE,
255 message: INVALID_REQUEST_ERROR_MSG.to_string(),
256 data: None,
257 }),
258 id: Some(rpc_msg.id),
259 ..Default::default()
260 };
261 return SanityCheckResult::ErrRes(response);
262 }
263
264 let srvc_name = srvc_method[0].to_string();
265 let method_name = srvc_method[1..].join(".");
267
268 SanityCheckResult::NewReq(NewRequest {
269 srvc_name,
270 method_name,
271 msg: rpc_msg,
272 })
273 }
274
275 async fn new_request(
277 self: &Arc<Self>,
278 queue: Arc<ResponseQueue<serde_json::Value>>,
279 channel: Arc<Channel>,
280 msg: serde_json::Value,
281 ) {
282 trace!("--> new request {msg}");
283 let on_complete = |result: TaskResult<Result<()>>| async move {
284 if let TaskResult::Completed(Err(err)) = result {
285 error!("Handle a new request: {err}");
286 }
287 };
288
289 self.task_group.spawn(
292 {
293 let this = self.clone();
294 async move {
295 let response = this.handle_request(channel, msg).await;
296 debug!("--> {response}");
297 queue.push(serde_json::json!(response)).await;
298 Ok(())
299 }
300 },
301 on_complete,
302 );
303 }
304
305 async fn handle_request(
308 &self,
309 channel: Arc<Channel>,
310 msg: serde_json::Value,
311 ) -> message::Response {
312 let req = match self.sanity_check(msg) {
313 SanityCheckResult::NewReq(req) => req,
314 SanityCheckResult::ErrRes(res) => return res,
315 };
316
317 let mut response = message::Response {
318 error: None,
319 result: None,
320 id: Some(req.msg.id.clone()),
321 ..Default::default()
322 };
323
324 if let Some(service) = self.config.pubsub_services.get(&req.srvc_name) {
326 if let Some(method) = service.get_pubsub_method(&req.method_name) {
328 let params = req.msg.params.unwrap_or(serde_json::json!(()));
329 let result = method(channel, req.msg.method, params);
330 response.result = match result.await {
331 Ok(res) => Some(res),
332 Err(err) => return err.to_response(Some(req.msg.id), None),
333 };
334
335 return response;
336 }
337 }
338
339 if let Some(service) = self.config.services.get(&req.srvc_name) {
341 if let Some(method) = service.get_method(&req.method_name) {
343 let params = req.msg.params.unwrap_or(serde_json::json!(()));
344 let result = method(params);
345 response.result = match result.await {
346 Ok(res) => Some(res),
347 Err(err) => return err.to_response(Some(req.msg.id), None),
348 };
349
350 return response;
351 }
352 }
353
354 response.error = Some(message::Error {
355 code: message::METHOD_NOT_FOUND_ERROR_CODE,
356 message: METHOD_NOT_FOUND_ERROR_MSG.to_string(),
357 data: None,
358 });
359
360 response
361 }
362
363 async fn init(
365 config: ServerConfig,
366 ex: Option<Executor>,
367 codec: impl ClonableJsonCodec + 'static,
368 ) -> Result<Arc<Self>> {
369 let task_group = match ex {
370 Some(ex) => TaskGroup::with_executor(ex),
371 None => TaskGroup::new(),
372 };
373 let listener = Self::listen(&config, codec).await?;
374 info!("RPC server listens to the endpoint: {}", config.endpoint);
375
376 let server = Arc::new(Server {
377 listener,
378 task_group,
379 config,
380 });
381
382 Ok(server)
383 }
384
385 async fn listen(
386 config: &ServerConfig,
387 codec: impl ClonableJsonCodec + 'static,
388 ) -> Result<Listener<serde_json::Value, Error>> {
389 let endpoint = config.endpoint.clone();
390 let listener: Listener<serde_json::Value, Error> = match endpoint {
391 #[cfg(feature = "tcp")]
392 Endpoint::Tcp(..) => Box::new(
393 karyon_net::tcp::listen(&endpoint, config.tcp_config.clone(), codec).await?,
394 ),
395 #[cfg(feature = "tls")]
396 Endpoint::Tls(..) => match &config.tls_config {
397 Some(conf) => Box::new(
398 karyon_net::tls::listen(
399 &endpoint,
400 karyon_net::tls::ServerTlsConfig {
401 server_config: conf.clone(),
402 tcp_config: config.tcp_config.clone(),
403 },
404 codec,
405 )
406 .await?,
407 ),
408 None => return Err(Error::TLSConfigRequired),
409 },
410 #[cfg(feature = "ws")]
411 Endpoint::Ws(..) => {
412 let config = ServerWsConfig {
413 tcp_config: config.tcp_config.clone(),
414 wss_config: None,
415 };
416 Box::new(karyon_net::ws::listen(&endpoint, config, codec).await?)
417 }
418 #[cfg(all(feature = "ws", feature = "tls"))]
419 Endpoint::Wss(..) => match &config.tls_config {
420 Some(conf) => Box::new(
421 karyon_net::ws::listen(
422 &endpoint,
423 ServerWsConfig {
424 tcp_config: config.tcp_config.clone(),
425 wss_config: Some(karyon_net::ws::ServerWssConfig {
426 server_config: conf.clone(),
427 }),
428 },
429 codec,
430 )
431 .await?,
432 ),
433 None => return Err(Error::TLSConfigRequired),
434 },
435 #[cfg(all(feature = "unix", target_family = "unix"))]
436 Endpoint::Unix(..) => Box::new(karyon_net::unix::listen(
437 &endpoint,
438 Default::default(),
439 codec,
440 )?),
441
442 _ => return Err(Error::UnsupportedProtocol(endpoint.to_string())),
443 };
444
445 Ok(listener)
446 }
447}