1mod acceptor;
2pub mod builder;
3pub mod channel;
4mod dispatch;
5pub mod pubsub_service;
6pub mod service;
7
8#[cfg(feature = "quic")]
9mod quic;
10
11#[cfg(feature = "http")]
12mod http;
13
14use std::{collections::HashMap, sync::Arc};
15
16use log::{debug, error, info};
17
18use karyon_core::{
19 async_runtime::Executor,
20 async_util::{select, AsyncQueue, Either, TaskGroup, TaskResult},
21};
22
23use karyon_net::{Endpoint, MessageRx, MessageTx};
24
25#[cfg(feature = "tcp")]
26use karyon_net::tcp::{TcpConfig, TcpListener};
27#[cfg(feature = "tls")]
28use karyon_net::tls::TlsListener;
29#[cfg(all(feature = "unix", target_family = "unix"))]
30use karyon_net::unix::UnixListener;
31
32use crate::{
33 codec::JsonRpcCodec,
34 error::{Error, Result},
35 message,
36 server::{
37 acceptor::{AsyncAcceptor, StreamAcceptor},
38 channel::NewNotification,
39 },
40};
41
42#[cfg(feature = "ws")]
43use crate::{codec::JsonRpcWsCodec, server::acceptor::WsAcceptor};
44
45pub use builder::ServerBuilder;
46pub use channel::Channel;
47pub use pubsub_service::{PubSubRPCMethod, PubSubRPCService};
48pub use service::{RPCMethod, RPCService};
49
// Publicly exported error-message strings. Where they are emitted is outside
// this part of the file; presumably they are used as the message of JSON-RPC
// error responses — confirm against the dispatch code.

/// Error message for an invalid request.
pub const INVALID_REQUEST_ERROR_MSG: &str = "Invalid request";
/// Error message for a message that could not be parsed.
pub const FAILED_TO_PARSE_ERROR_MSG: &str = "Failed to parse";
/// Error message for a call to an unknown method.
pub const METHOD_NOT_FOUND_ERROR_MSG: &str = "Method not found";
/// Error message for a request carrying an unsupported `jsonrpc` version.
pub const UNSUPPORTED_JSONRPC_VERSION: &str = "Unsupported jsonrpc version";

/// Capacity of the bounded per-connection notification channel created in
/// `Server::handle_message_conn`.
const CHANNEL_SUBSCRIPTION_BUFFER_SIZE: usize = 100;
/// Capacity of the per-connection response queue created in
/// `Server::handle_message_conn`.
const RESPONSE_QUEUE_SIZE: usize = 256;
59
/// Settings a [`Server`] is created from.
///
/// `create_backend` reads it to bind the listener, and `Server::init` then
/// stores it on the server.
pub(crate) struct ServerConfig {
    /// Endpoint to listen on; its variant selects the backend in `create_backend`.
    pub endpoint: Endpoint,
    /// Options passed to `TcpListener::bind` for TCP, TLS, WS and WSS endpoints.
    #[cfg(feature = "tcp")]
    pub tcp_config: TcpConfig,
    /// TLS settings; required for TLS and WSS endpoints
    /// (`Error::TLSConfigRequired` is returned when missing).
    #[cfg(feature = "tls")]
    pub tls_config: Option<karyon_net::tls::ServerTlsConfig>,
    /// QUIC settings; required for QUIC endpoints (`Error::QUICConfigRequired`
    /// otherwise). With the `http3` feature, an HTTP endpoint is served over
    /// HTTP/3 when this is `Some`.
    #[cfg(feature = "quic")]
    pub quic_config: Option<karyon_net::quic::ServerQuicConfig>,
    /// Registered RPC services. NOTE(review): keys are presumably the service
    /// names used during method dispatch — dispatch is outside this file section.
    pub services: HashMap<String, Arc<dyn RPCService + 'static>>,
    /// Registered pub/sub RPC services, keyed the same way as `services`.
    pub pubsub_services: HashMap<String, Arc<dyn PubSubRPCService + 'static>>,
    /// Turns a `NewNotification` into the `message::Notification` that the
    /// per-connection writer task serializes and sends.
    pub notification_encoder: fn(NewNotification) -> message::Notification,
}
74
/// Transport-specific listener the server accepts connections from.
enum ServerBackend {
    /// Byte-stream listeners (TCP, TLS, Unix, WS, WSS) behind an `AsyncAcceptor`;
    /// the accept loop calls `accept_and_handle` on it repeatedly.
    StreamAcceptor(Box<dyn AsyncAcceptor>),
    /// QUIC endpoint; each accepted connection is handed to `handle_quic_conn`.
    #[cfg(feature = "quic")]
    QuicEndpoint(karyon_net::quic::QuicEndpoint),
    /// HTTP server, driven by `http::accept_loop`.
    #[cfg(feature = "http")]
    Http(http::HttpServer),
}
84
/// A JSON-RPC server bound to a single endpoint.
pub struct Server {
    /// Listener selected from the configured endpoint.
    backend: ServerBackend,
    /// Task group used to spawn the accept loop (via `start`), the
    /// per-connection reader/writer tasks and the per-request tasks;
    /// `shutdown` cancels it.
    pub(crate) task_group: Arc<TaskGroup>,
    /// Configuration the server was created with.
    pub(crate) config: ServerConfig,
}
91
92impl Server {
93 pub fn start(self: Arc<Self>) {
94 self.task_group
95 .spawn(self.clone().start_block(), |_| async {});
96 }
97
98 pub async fn start_block(self: Arc<Self>) -> Result<()> {
99 if let Err(err) = self.accept_loop().await {
100 error!("Main accept loop stopped: {err}");
101 self.shutdown().await;
102 };
103 Ok(())
104 }
105
106 async fn accept_loop(self: &Arc<Self>) -> Result<()> {
107 match &self.backend {
108 ServerBackend::StreamAcceptor(acceptor) => loop {
109 if let Err(err) = acceptor.accept_and_handle(self).await {
110 error!("Accept connection: {err}");
111 }
112 },
113 #[cfg(feature = "quic")]
114 ServerBackend::QuicEndpoint(endpoint) => loop {
115 match endpoint.accept().await {
116 Ok(quic_conn) => {
117 if let Err(err) = self.handle_quic_conn(quic_conn) {
118 error!("Handle QUIC conn: {err}")
119 }
120 }
121 Err(err) => {
122 error!("Accept QUIC conn: {err}")
123 }
124 }
125 },
126 #[cfg(feature = "http")]
127 ServerBackend::Http(http_server) => {
128 http::accept_loop(self.clone(), http_server).await?;
129 Ok(())
130 }
131 }
132 }
133
134 pub fn local_endpoint(&self) -> Result<Endpoint> {
135 match &self.backend {
136 ServerBackend::StreamAcceptor(acceptor) => acceptor.local_endpoint(),
137 #[cfg(feature = "quic")]
138 ServerBackend::QuicEndpoint(endpoint) => endpoint.local_endpoint().map_err(Error::from),
139 #[cfg(feature = "http")]
140 ServerBackend::Http(http_server) => http_server.local_endpoint(),
141 }
142 }
143
144 pub async fn shutdown(&self) {
145 self.task_group.cancel().await;
146 }
147
148 pub(crate) fn handle_message_conn<R, W>(
152 self: &Arc<Self>,
153 reader: R,
154 writer: W,
155 peer: Option<Endpoint>,
156 ) where
157 R: MessageRx<Message = serde_json::Value> + Send + 'static,
158 W: MessageTx<Message = serde_json::Value> + Send + 'static,
159 {
160 debug!("Handle connection {peer:?}");
161
162 let (ch_tx, ch_rx) = async_channel::bounded(CHANNEL_SUBSCRIPTION_BUFFER_SIZE);
163 let channel = Channel::new(ch_tx);
164 let queue = AsyncQueue::new(RESPONSE_QUEUE_SIZE);
165
166 let writer_chan = channel.clone();
167 self.task_group.spawn(
168 stream_writer_task(
169 writer,
170 queue.clone(),
171 ch_rx,
172 self.config.notification_encoder,
173 ),
174 |result: TaskResult<Result<()>>| async move {
175 if let TaskResult::Completed(Err(err)) = result {
176 debug!("Writer stopped: {err}");
177 }
178 writer_chan.close();
179 },
180 );
181
182 let reader_chan = channel.clone();
183 self.task_group.spawn(
184 stream_reader_task(self.clone(), reader, queue, channel),
185 |result: TaskResult<Result<()>>| async move {
186 if let TaskResult::Completed(Err(err)) = result {
187 debug!("Connection {peer:?} dropped: {err}");
188 } else {
189 debug!("Connection {peer:?} dropped");
190 }
191 reader_chan.close();
192 },
193 );
194 }
195
196 async fn new_request(
197 self: &Arc<Self>,
198 queue: Arc<AsyncQueue<serde_json::Value>>,
199 channel: Arc<Channel>,
200 msg: serde_json::Value,
201 ) {
202 self.task_group.spawn(
203 request_task(self.clone(), queue, channel, msg),
204 |result: TaskResult<Result<()>>| async move {
205 if let TaskResult::Completed(Err(err)) = result {
206 error!("Handle request: {err}");
207 }
208 },
209 );
210 }
211
212 pub(super) async fn init<B, W>(
213 config: ServerConfig,
214 ex: Option<Executor>,
215 byte_codec: B,
216 ws_codec: W,
217 ) -> Result<Arc<Self>>
218 where
219 B: JsonRpcCodec,
220 W: WsCodec,
221 {
222 let task_group = Arc::new(match ex {
223 Some(ex) => TaskGroup::with_executor(ex),
224 None => TaskGroup::new(),
225 });
226
227 let backend = create_backend(&config, byte_codec, ws_codec).await?;
228 info!("RPC server listens to the endpoint: {}", config.endpoint);
229
230 Ok(Arc::new(Server {
231 backend,
232 task_group,
233 config,
234 }))
235 }
236}
237
/// Bound placed on the WebSocket codec handed to `Server::init`.
///
/// With the `ws` feature enabled, this is implemented for every
/// `JsonRpcWsCodec`.
#[cfg(feature = "ws")]
pub trait WsCodec: JsonRpcWsCodec {}

#[cfg(feature = "ws")]
impl<C> WsCodec for C where C: JsonRpcWsCodec {}

/// Bound placed on the WebSocket codec handed to `Server::init`.
///
/// Without the `ws` feature, the WebSocket codec is only consumed by
/// `ws`-gated code paths. Any `Clone + Send + Sync + 'static` type therefore
/// satisfies the bound.
#[cfg(not(feature = "ws"))]
pub trait WsCodec: Clone + Send + Sync + 'static {}

#[cfg(not(feature = "ws"))]
impl<C> WsCodec for C where C: Clone + Send + Sync + 'static {}
250
251async fn create_backend<B, W>(
252 config: &ServerConfig,
253 byte_codec: B,
254 ws_codec: W,
255) -> Result<ServerBackend>
256where
257 B: JsonRpcCodec,
258 W: WsCodec,
259{
260 let endpoint = config.endpoint.clone();
261 match endpoint {
262 #[cfg(feature = "http")]
263 Endpoint::Http(..) => {
264 #[cfg(feature = "http3")]
265 let http_server = match config.quic_config.clone() {
266 Some(quic_cfg) => http::HttpServer::new_h3(&endpoint, quic_cfg).await?,
267 None => http::HttpServer::new(&endpoint).await?,
268 };
269 #[cfg(not(feature = "http3"))]
270 let http_server = http::HttpServer::new(&endpoint).await?;
271 Ok(ServerBackend::Http(http_server))
272 }
273 #[cfg(feature = "quic")]
274 Endpoint::Quic(..) => match &config.quic_config {
275 Some(conf) => {
276 let quic_endpoint =
277 karyon_net::quic::QuicEndpoint::listen(&endpoint, conf.clone()).await?;
278 Ok(ServerBackend::QuicEndpoint(quic_endpoint))
279 }
280 None => Err(Error::QUICConfigRequired),
281 },
282 #[cfg(feature = "tcp")]
283 Endpoint::Tcp(..) => {
284 let listener = TcpListener::bind(&endpoint, config.tcp_config.clone()).await?;
285 Ok(ServerBackend::StreamAcceptor(Box::new(StreamAcceptor {
286 listener: Box::new(listener),
287 codec: byte_codec,
288 })))
289 }
290 #[cfg(feature = "tls")]
291 Endpoint::Tls(..) => {
292 let tls_config = config.tls_config.as_ref().ok_or(Error::TLSConfigRequired)?;
293 let tcp_listener = TcpListener::bind(&endpoint, config.tcp_config.clone()).await?;
294 let listener = TlsListener::new(tcp_listener, tls_config.clone());
295 Ok(ServerBackend::StreamAcceptor(Box::new(StreamAcceptor {
296 listener: Box::new(listener),
297 codec: byte_codec,
298 })))
299 }
300 #[cfg(feature = "ws")]
301 Endpoint::Ws(..) => {
302 let listener = TcpListener::bind(&endpoint, config.tcp_config.clone()).await?;
303 let layer = Arc::new(karyon_net::layers::ws::WsLayer::server(ws_codec));
304 Ok(ServerBackend::StreamAcceptor(Box::new(WsAcceptor {
305 listener: Box::new(listener),
306 layer,
307 tls: false,
308 })))
309 }
310 #[cfg(all(feature = "ws", feature = "tls"))]
311 Endpoint::Wss(..) => {
312 let tls_config = config.tls_config.as_ref().ok_or(Error::TLSConfigRequired)?;
313 let tcp_listener = TcpListener::bind(&endpoint, config.tcp_config.clone()).await?;
314 let listener = TlsListener::new(tcp_listener, tls_config.clone());
315 let layer = Arc::new(karyon_net::layers::ws::WsLayer::server(ws_codec));
316 Ok(ServerBackend::StreamAcceptor(Box::new(WsAcceptor {
317 listener: Box::new(listener),
318 layer,
319 tls: true,
320 })))
321 }
322 #[cfg(all(feature = "unix", target_family = "unix"))]
323 Endpoint::Unix(..) => {
324 let listener = UnixListener::bind(&endpoint)?;
325 Ok(ServerBackend::StreamAcceptor(Box::new(StreamAcceptor {
326 listener: Box::new(listener),
327 codec: byte_codec,
328 })))
329 }
330 _ => Err(Error::UnsupportedProtocol(endpoint.to_string())),
331 }
332}
333
334async fn stream_writer_task<W>(
335 mut writer: W,
336 queue: Arc<AsyncQueue<serde_json::Value>>,
337 ch_rx: async_channel::Receiver<NewNotification>,
338 notification_encoder: fn(NewNotification) -> message::Notification,
339) -> Result<()>
340where
341 W: MessageTx<Message = serde_json::Value> + Send,
342{
343 loop {
344 match select(queue.recv(), ch_rx.recv()).await {
345 Either::Left(res) => {
346 writer.send_msg(res).await?;
347 }
348 Either::Right(notification) => {
349 let nt = notification?;
350 let notification = notification_encoder(nt);
351 debug!("--> {notification}");
352 writer.send_msg(serde_json::json!(notification)).await?;
353 }
354 }
355 }
356}
357
358async fn stream_reader_task<R>(
359 server: Arc<Server>,
360 mut reader: R,
361 queue: Arc<AsyncQueue<serde_json::Value>>,
362 channel: Arc<Channel>,
363) -> Result<()>
364where
365 R: MessageRx<Message = serde_json::Value> + Send,
366{
367 loop {
368 let msg = reader.recv_msg().await?;
369 server
370 .new_request(queue.clone(), channel.clone(), msg)
371 .await;
372 }
373}
374
375async fn request_task(
376 server: Arc<Server>,
377 queue: Arc<AsyncQueue<serde_json::Value>>,
378 channel: Arc<Channel>,
379 msg: serde_json::Value,
380) -> Result<()> {
381 let response = server.handle_request(Some(channel), msg).await;
382 debug!("--> {response}");
383 queue.push(serde_json::json!(response)).await;
384 Ok(())
385}