Skip to main content

karyon_jsonrpc/server/
mod.rs

1mod acceptor;
2pub mod builder;
3pub mod channel;
4mod dispatch;
5pub mod pubsub_service;
6pub mod service;
7
8#[cfg(feature = "quic")]
9mod quic;
10
11#[cfg(feature = "http")]
12mod http;
13
14use std::{collections::HashMap, sync::Arc};
15
16use log::{debug, error, info};
17
18use karyon_core::{
19    async_runtime::Executor,
20    async_util::{select, AsyncQueue, Either, TaskGroup, TaskResult},
21};
22
23use karyon_net::{Endpoint, MessageRx, MessageTx};
24
25#[cfg(feature = "tcp")]
26use karyon_net::tcp::{TcpConfig, TcpListener};
27#[cfg(feature = "tls")]
28use karyon_net::tls::TlsListener;
29#[cfg(all(feature = "unix", target_family = "unix"))]
30use karyon_net::unix::UnixListener;
31
32use crate::{
33    codec::JsonRpcCodec,
34    error::{Error, Result},
35    message,
36    server::{
37        acceptor::{AsyncAcceptor, StreamAcceptor},
38        channel::NewNotification,
39    },
40};
41
42#[cfg(feature = "ws")]
43use crate::{codec::JsonRpcWsCodec, server::acceptor::WsAcceptor};
44
45pub use builder::ServerBuilder;
46pub use channel::Channel;
47pub use pubsub_service::{PubSubRPCMethod, PubSubRPCService};
48pub use service::{RPCMethod, RPCService};
49
/// Error message text: "Invalid request".
pub const INVALID_REQUEST_ERROR_MSG: &str = "Invalid request";
/// Error message text: "Failed to parse".
pub const FAILED_TO_PARSE_ERROR_MSG: &str = "Failed to parse";
/// Error message text: "Method not found".
pub const METHOD_NOT_FOUND_ERROR_MSG: &str = "Method not found";
/// Message text: "Unsupported jsonrpc version".
pub const UNSUPPORTED_JSONRPC_VERSION: &str = "Unsupported jsonrpc version";

/// Capacity of the bounded `async_channel` created per connection in
/// `Server::handle_message_conn`; its receiving half is drained by the
/// connection's writer task as `NewNotification` values.
const CHANNEL_SUBSCRIPTION_BUFFER_SIZE: usize = 100;

/// Bound on the per-connection outbound response queue.
const RESPONSE_QUEUE_SIZE: usize = 256;
59
/// Settings a [`Server`] is built from.
///
/// Transport-specific fields only exist when the matching cargo feature
/// is enabled.
pub(crate) struct ServerConfig {
    /// Endpoint the server listens on. Its variant selects which backend
    /// `create_backend` builds.
    pub endpoint: Endpoint,
    /// Passed to `TcpListener::bind` for the TCP, TLS, WS and WSS endpoints.
    #[cfg(feature = "tcp")]
    pub tcp_config: TcpConfig,
    /// TLS settings. `create_backend` fails with `Error::TLSConfigRequired`
    /// when this is `None` and the endpoint is `Tls` or `Wss`.
    #[cfg(feature = "tls")]
    pub tls_config: Option<karyon_net::tls::ServerTlsConfig>,
    /// QUIC settings. `create_backend` fails with `Error::QUICConfigRequired`
    /// when this is `None` and the endpoint is `Quic`. With the `http3`
    /// feature, a `Some` value makes an `Http` endpoint use
    /// `HttpServer::new_h3` instead of `HttpServer::new`.
    #[cfg(feature = "quic")]
    pub quic_config: Option<karyon_net::quic::ServerQuicConfig>,
    /// Registered RPC services, keyed by string
    /// (presumably the service name -- confirm against `ServerBuilder`).
    pub services: HashMap<String, Arc<dyn RPCService + 'static>>,
    /// Registered pub/sub RPC services, keyed by string
    /// (presumably the service name -- confirm against `ServerBuilder`).
    pub pubsub_services: HashMap<String, Arc<dyn PubSubRPCService + 'static>>,
    /// User-customizable notification wire format.
    /// Set via `ServerBuilder::with_notification_encoder`.
    /// The connection writer task applies it to every `NewNotification`
    /// before sending the result to the peer.
    pub notification_encoder: fn(NewNotification) -> message::Notification,
}
74
/// One variant per backend family. Stream-based transports share
/// a single `AsyncAcceptor`; QUIC and HTTP have their own loops.
enum ServerBackend {
    /// Boxed acceptor built by `create_backend` for the TCP, TLS, WS, WSS
    /// and Unix endpoints.
    StreamAcceptor(Box<dyn AsyncAcceptor>),
    /// Listening QUIC endpoint; connections are accepted one by one in
    /// `Server::accept_loop`.
    #[cfg(feature = "quic")]
    QuicEndpoint(karyon_net::quic::QuicEndpoint),
    /// HTTP server; its accept loop lives in the `http` submodule.
    #[cfg(feature = "http")]
    Http(http::HttpServer),
}
84
/// A JSON-RPC 2.0 server.
pub struct Server {
    /// Transport backend that incoming connections are accepted from.
    backend: ServerBackend,
    /// Task group on which the accept loop (when started via `start`),
    /// the per-connection reader/writer tasks and the per-request tasks
    /// are spawned. `shutdown` cancels it.
    pub(crate) task_group: Arc<TaskGroup>,
    /// Configuration the server was initialized with.
    pub(crate) config: ServerConfig,
}
91
92impl Server {
93    pub fn start(self: Arc<Self>) {
94        self.task_group
95            .spawn(self.clone().start_block(), |_| async {});
96    }
97
98    pub async fn start_block(self: Arc<Self>) -> Result<()> {
99        if let Err(err) = self.accept_loop().await {
100            error!("Main accept loop stopped: {err}");
101            self.shutdown().await;
102        };
103        Ok(())
104    }
105
106    async fn accept_loop(self: &Arc<Self>) -> Result<()> {
107        match &self.backend {
108            ServerBackend::StreamAcceptor(acceptor) => loop {
109                if let Err(err) = acceptor.accept_and_handle(self).await {
110                    error!("Accept connection: {err}");
111                }
112            },
113            #[cfg(feature = "quic")]
114            ServerBackend::QuicEndpoint(endpoint) => loop {
115                match endpoint.accept().await {
116                    Ok(quic_conn) => {
117                        if let Err(err) = self.handle_quic_conn(quic_conn) {
118                            error!("Handle QUIC conn: {err}")
119                        }
120                    }
121                    Err(err) => {
122                        error!("Accept QUIC conn: {err}")
123                    }
124                }
125            },
126            #[cfg(feature = "http")]
127            ServerBackend::Http(http_server) => {
128                http::accept_loop(self.clone(), http_server).await?;
129                Ok(())
130            }
131        }
132    }
133
134    pub fn local_endpoint(&self) -> Result<Endpoint> {
135        match &self.backend {
136            ServerBackend::StreamAcceptor(acceptor) => acceptor.local_endpoint(),
137            #[cfg(feature = "quic")]
138            ServerBackend::QuicEndpoint(endpoint) => endpoint.local_endpoint().map_err(Error::from),
139            #[cfg(feature = "http")]
140            ServerBackend::Http(http_server) => http_server.local_endpoint(),
141        }
142    }
143
144    pub async fn shutdown(&self) {
145        self.task_group.cancel().await;
146    }
147
148    /// Handle a split message connection (framed byte stream or
149    /// WebSocket). Spawns a reader and writer task; the writer drains
150    /// both responses and notifications.
151    pub(crate) fn handle_message_conn<R, W>(
152        self: &Arc<Self>,
153        reader: R,
154        writer: W,
155        peer: Option<Endpoint>,
156    ) where
157        R: MessageRx<Message = serde_json::Value> + Send + 'static,
158        W: MessageTx<Message = serde_json::Value> + Send + 'static,
159    {
160        debug!("Handle connection {peer:?}");
161
162        let (ch_tx, ch_rx) = async_channel::bounded(CHANNEL_SUBSCRIPTION_BUFFER_SIZE);
163        let channel = Channel::new(ch_tx);
164        let queue = AsyncQueue::new(RESPONSE_QUEUE_SIZE);
165
166        let writer_chan = channel.clone();
167        self.task_group.spawn(
168            stream_writer_task(
169                writer,
170                queue.clone(),
171                ch_rx,
172                self.config.notification_encoder,
173            ),
174            |result: TaskResult<Result<()>>| async move {
175                if let TaskResult::Completed(Err(err)) = result {
176                    debug!("Writer stopped: {err}");
177                }
178                writer_chan.close();
179            },
180        );
181
182        let reader_chan = channel.clone();
183        self.task_group.spawn(
184            stream_reader_task(self.clone(), reader, queue, channel),
185            |result: TaskResult<Result<()>>| async move {
186                if let TaskResult::Completed(Err(err)) = result {
187                    debug!("Connection {peer:?} dropped: {err}");
188                } else {
189                    debug!("Connection {peer:?} dropped");
190                }
191                reader_chan.close();
192            },
193        );
194    }
195
196    async fn new_request(
197        self: &Arc<Self>,
198        queue: Arc<AsyncQueue<serde_json::Value>>,
199        channel: Arc<Channel>,
200        msg: serde_json::Value,
201    ) {
202        self.task_group.spawn(
203            request_task(self.clone(), queue, channel, msg),
204            |result: TaskResult<Result<()>>| async move {
205                if let TaskResult::Completed(Err(err)) = result {
206                    error!("Handle request: {err}");
207                }
208            },
209        );
210    }
211
212    pub(super) async fn init<B, W>(
213        config: ServerConfig,
214        ex: Option<Executor>,
215        byte_codec: B,
216        ws_codec: W,
217    ) -> Result<Arc<Self>>
218    where
219        B: JsonRpcCodec,
220        W: WsCodec,
221    {
222        let task_group = Arc::new(match ex {
223            Some(ex) => TaskGroup::with_executor(ex),
224            None => TaskGroup::new(),
225        });
226
227        let backend = create_backend(&config, byte_codec, ws_codec).await?;
228        info!("RPC server listens to the endpoint: {}", config.endpoint);
229
230        Ok(Arc::new(Server {
231            backend,
232            task_group,
233            config,
234        }))
235    }
236}
237
/// Bound on the WebSocket codec generic. With the `ws` feature it
/// requires `JsonRpcWsCodec`; otherwise it accepts any clonable type
/// (the codec is unused) so callers can pass `JsonCodec` unchanged.
#[cfg(feature = "ws")]
pub trait WsCodec: JsonRpcWsCodec {}
// Blanket impl: every `JsonRpcWsCodec` is a `WsCodec`.
#[cfg(feature = "ws")]
impl<T: JsonRpcWsCodec> WsCodec for T {}

/// Variant of `WsCodec` compiled when the `ws` feature is disabled:
/// a marker bound satisfied by any `Clone + Send + Sync + 'static` type.
#[cfg(not(feature = "ws"))]
pub trait WsCodec: Clone + Send + Sync + 'static {}
// Blanket impl: every `Clone + Send + Sync + 'static` type is a `WsCodec`.
#[cfg(not(feature = "ws"))]
impl<T: Clone + Send + Sync + 'static> WsCodec for T {}
250
251async fn create_backend<B, W>(
252    config: &ServerConfig,
253    byte_codec: B,
254    ws_codec: W,
255) -> Result<ServerBackend>
256where
257    B: JsonRpcCodec,
258    W: WsCodec,
259{
260    let endpoint = config.endpoint.clone();
261    match endpoint {
262        #[cfg(feature = "http")]
263        Endpoint::Http(..) => {
264            #[cfg(feature = "http3")]
265            let http_server = match config.quic_config.clone() {
266                Some(quic_cfg) => http::HttpServer::new_h3(&endpoint, quic_cfg).await?,
267                None => http::HttpServer::new(&endpoint).await?,
268            };
269            #[cfg(not(feature = "http3"))]
270            let http_server = http::HttpServer::new(&endpoint).await?;
271            Ok(ServerBackend::Http(http_server))
272        }
273        #[cfg(feature = "quic")]
274        Endpoint::Quic(..) => match &config.quic_config {
275            Some(conf) => {
276                let quic_endpoint =
277                    karyon_net::quic::QuicEndpoint::listen(&endpoint, conf.clone()).await?;
278                Ok(ServerBackend::QuicEndpoint(quic_endpoint))
279            }
280            None => Err(Error::QUICConfigRequired),
281        },
282        #[cfg(feature = "tcp")]
283        Endpoint::Tcp(..) => {
284            let listener = TcpListener::bind(&endpoint, config.tcp_config.clone()).await?;
285            Ok(ServerBackend::StreamAcceptor(Box::new(StreamAcceptor {
286                listener: Box::new(listener),
287                codec: byte_codec,
288            })))
289        }
290        #[cfg(feature = "tls")]
291        Endpoint::Tls(..) => {
292            let tls_config = config.tls_config.as_ref().ok_or(Error::TLSConfigRequired)?;
293            let tcp_listener = TcpListener::bind(&endpoint, config.tcp_config.clone()).await?;
294            let listener = TlsListener::new(tcp_listener, tls_config.clone());
295            Ok(ServerBackend::StreamAcceptor(Box::new(StreamAcceptor {
296                listener: Box::new(listener),
297                codec: byte_codec,
298            })))
299        }
300        #[cfg(feature = "ws")]
301        Endpoint::Ws(..) => {
302            let listener = TcpListener::bind(&endpoint, config.tcp_config.clone()).await?;
303            let layer = Arc::new(karyon_net::layers::ws::WsLayer::server(ws_codec));
304            Ok(ServerBackend::StreamAcceptor(Box::new(WsAcceptor {
305                listener: Box::new(listener),
306                layer,
307                tls: false,
308            })))
309        }
310        #[cfg(all(feature = "ws", feature = "tls"))]
311        Endpoint::Wss(..) => {
312            let tls_config = config.tls_config.as_ref().ok_or(Error::TLSConfigRequired)?;
313            let tcp_listener = TcpListener::bind(&endpoint, config.tcp_config.clone()).await?;
314            let listener = TlsListener::new(tcp_listener, tls_config.clone());
315            let layer = Arc::new(karyon_net::layers::ws::WsLayer::server(ws_codec));
316            Ok(ServerBackend::StreamAcceptor(Box::new(WsAcceptor {
317                listener: Box::new(listener),
318                layer,
319                tls: true,
320            })))
321        }
322        #[cfg(all(feature = "unix", target_family = "unix"))]
323        Endpoint::Unix(..) => {
324            let listener = UnixListener::bind(&endpoint)?;
325            Ok(ServerBackend::StreamAcceptor(Box::new(StreamAcceptor {
326                listener: Box::new(listener),
327                codec: byte_codec,
328            })))
329        }
330        _ => Err(Error::UnsupportedProtocol(endpoint.to_string())),
331    }
332}
333
334async fn stream_writer_task<W>(
335    mut writer: W,
336    queue: Arc<AsyncQueue<serde_json::Value>>,
337    ch_rx: async_channel::Receiver<NewNotification>,
338    notification_encoder: fn(NewNotification) -> message::Notification,
339) -> Result<()>
340where
341    W: MessageTx<Message = serde_json::Value> + Send,
342{
343    loop {
344        match select(queue.recv(), ch_rx.recv()).await {
345            Either::Left(res) => {
346                writer.send_msg(res).await?;
347            }
348            Either::Right(notification) => {
349                let nt = notification?;
350                let notification = notification_encoder(nt);
351                debug!("--> {notification}");
352                writer.send_msg(serde_json::json!(notification)).await?;
353            }
354        }
355    }
356}
357
358async fn stream_reader_task<R>(
359    server: Arc<Server>,
360    mut reader: R,
361    queue: Arc<AsyncQueue<serde_json::Value>>,
362    channel: Arc<Channel>,
363) -> Result<()>
364where
365    R: MessageRx<Message = serde_json::Value> + Send,
366{
367    loop {
368        let msg = reader.recv_msg().await?;
369        server
370            .new_request(queue.clone(), channel.clone(), msg)
371            .await;
372    }
373}
374
375async fn request_task(
376    server: Arc<Server>,
377    queue: Arc<AsyncQueue<serde_json::Value>>,
378    channel: Arc<Channel>,
379    msg: serde_json::Value,
380) -> Result<()> {
381    let response = server.handle_request(Some(channel), msg).await;
382    debug!("--> {response}");
383    queue.push(serde_json::json!(response)).await;
384    Ok(())
385}