karyon_jsonrpc/server/http/
h3.rs1use std::{collections::HashMap, sync::Arc};
4
5use bytes::{Buf, Bytes};
6use h3_quinn::Connection;
7use hyper::{Method, Request, Response, StatusCode};
8use log::{debug, error};
9
10use karyon_core::{async_runtime::lock::RwLock, async_util::TaskResult};
11
12use karyon_net::quic::{QuicConn, QuicEndpoint};
13
14use crate::{
15 error::{Error, Result},
16 message::{self, SubscriptionID},
17 server::{
18 channel::{Channel, NewNotification},
19 http::{ERR_BODY_TOO_LARGE, ERR_METHOD_NOT_ALLOWED, MAX_HTTP_BODY_SIZE},
20 Server, CHANNEL_SUBSCRIPTION_BUFFER_SIZE, FAILED_TO_PARSE_ERROR_MSG,
21 },
22};
23
24type H3Stream = h3::server::RequestStream<h3_quinn::BidiStream<Bytes>, Bytes>;
25
26type SubSenders = Arc<RwLock<HashMap<SubscriptionID, async_channel::Sender<NewNotification>>>>;
36
37pub(super) async fn accept_h3(server: &Arc<Server>, quic_ep: &QuicEndpoint) {
38 let quic_conn = match quic_ep.accept().await {
39 Ok(conn) => conn,
40 Err(err) => {
41 error!("Accept QUIC for HTTP/3: {err}");
42 return;
43 }
44 };
45
46 server.task_group.spawn(
47 serve_conn_task(server.clone(), quic_conn),
48 |_: TaskResult<Result<()>>| async {},
49 );
50}
51
52async fn serve_conn_task(server: Arc<Server>, quic_conn: QuicConn) -> Result<()> {
53 let peer = quic_conn
54 .peer_endpoint()
55 .map(|e| e.to_string())
56 .unwrap_or_default();
57 if let Err(err) = serve_conn(server, quic_conn).await {
58 debug!("HTTP/3 from {peer} closed: {err}");
59 }
60 Ok(())
61}
62
63async fn serve_conn(server: Arc<Server>, quic_conn: QuicConn) -> Result<()> {
64 let h3_conn = Connection::new(quic_conn.inner().clone());
65 let mut h3_server: h3::server::Connection<Connection, Bytes> =
66 h3::server::Connection::new(h3_conn)
67 .await
68 .map_err(|e| Error::HttpError(format!("H3 handshake: {e}")))?;
69
70 let (ch_tx, ch_rx) = async_channel::bounded(CHANNEL_SUBSCRIPTION_BUFFER_SIZE);
72 let channel = Channel::new(ch_tx);
73
74 let sub_senders: SubSenders = Arc::new(RwLock::new(HashMap::new()));
75
76 server.task_group.spawn(
79 dispatch_subs_task(ch_rx, sub_senders.clone()),
80 |_: TaskResult<Result<()>>| async {},
81 );
82
83 loop {
84 match h3_server.accept().await {
85 Ok(Some(resolver)) => {
86 let (req, stream) = match resolver.resolve_request().await {
87 Ok(v) => v,
88 Err(err) => {
89 error!("Resolve HTTP/3 request: {err}");
90 continue;
91 }
92 };
93 server.task_group.spawn(
94 handle_request_task(
95 server.clone(),
96 channel.clone(),
97 sub_senders.clone(),
98 req,
99 stream,
100 ),
101 |_: TaskResult<Result<()>>| async {},
102 );
103 }
104 Ok(None) => break,
105 Err(err) => {
106 debug!("HTTP/3 connection error: {err}");
107 break;
108 }
109 }
110 }
111
112 for (_, sender) in sub_senders.write().await.drain() {
113 sender.close();
114 }
115 channel.close();
116 Ok(())
117}
118
119async fn dispatch_subs_task(
120 ch_rx: async_channel::Receiver<NewNotification>,
121 sub_senders: SubSenders,
122) -> Result<()> {
123 while let Ok(nt) = ch_rx.recv().await {
124 let subs = sub_senders.read().await;
125 if let Some(sender) = subs.get(&nt.sub_id) {
126 let _ = sender.send(nt).await;
127 }
128 }
129 Ok(())
130}
131
132async fn handle_request_task(
133 server: Arc<Server>,
134 channel: Arc<Channel>,
135 sub_senders: SubSenders,
136 req: Request<()>,
137 stream: H3Stream,
138) -> Result<()> {
139 if let Err(err) = handle_h3_request(server, channel, sub_senders, req, stream).await {
140 error!("Handle HTTP/3 request: {err}");
141 }
142 Ok(())
143}
144
145async fn handle_h3_request(
146 server: Arc<Server>,
147 channel: Arc<Channel>,
148 sub_senders: SubSenders,
149 req: Request<()>,
150 mut stream: H3Stream,
151) -> Result<()> {
152 if req.method() != Method::POST {
153 h3_send(
154 &mut stream,
155 StatusCode::METHOD_NOT_ALLOWED,
156 ERR_METHOD_NOT_ALLOWED.as_bytes(),
157 )
158 .await?;
159 return Ok(());
160 }
161
162 let mut body = Vec::new();
163 while let Some(chunk) = stream.recv_data().await.map_err(h3_err)? {
164 body.extend_from_slice(Buf::chunk(&chunk));
165 if body.len() as u64 > MAX_HTTP_BODY_SIZE {
166 h3_send(
167 &mut stream,
168 StatusCode::PAYLOAD_TOO_LARGE,
169 ERR_BODY_TOO_LARGE.as_bytes(),
170 )
171 .await?;
172 return Ok(());
173 }
174 }
175
176 let msg: serde_json::Value = match serde_json::from_slice(&body) {
177 Ok(v) => v,
178 Err(_) => {
179 let resp = message::Response {
180 error: Some(message::Error {
181 code: message::PARSE_ERROR_CODE,
182 message: FAILED_TO_PARSE_ERROR_MSG.to_string(),
183 data: None,
184 }),
185 ..Default::default()
186 };
187 h3_send(
188 &mut stream,
189 StatusCode::OK,
190 &serde_json::to_vec(&resp).unwrap(),
191 )
192 .await?;
193 return Ok(());
194 }
195 };
196
197 let is_subscribe = msg
198 .get("method")
199 .and_then(|m| m.as_str())
200 .map(|m| m.ends_with("_subscribe") && !m.ends_with("_unsubscribe"))
201 .unwrap_or(false);
202
203 if is_subscribe {
204 return handle_h3_subscribe(server, sub_senders, msg, stream).await;
205 }
206
207 let response = server.handle_request(Some(channel.clone()), msg).await;
208 debug!("--> {response}");
209
210 if response.error.is_none() {
211 if let Ok(rpc_req) = serde_json::from_slice::<message::Request>(&body) {
212 if let Some(params) = &rpc_req.params {
213 if let Ok(sub_id) = serde_json::from_value::<SubscriptionID>(params.clone()) {
214 if let Some(sender) = sub_senders.write().await.remove(&sub_id) {
215 sender.close();
216 }
217 }
218 }
219 }
220 }
221
222 h3_send(
223 &mut stream,
224 StatusCode::OK,
225 &serde_json::to_vec(&response).unwrap(),
226 )
227 .await?;
228 Ok(())
229}
230
231async fn handle_h3_subscribe(
233 server: Arc<Server>,
234 sub_senders: SubSenders,
235 msg: serde_json::Value,
236 mut stream: H3Stream,
237) -> Result<()> {
238 let (sub_tx, sub_rx) = async_channel::bounded(CHANNEL_SUBSCRIPTION_BUFFER_SIZE);
239 let sub_channel = Channel::new(sub_tx.clone());
240
241 let response = server.handle_request(Some(sub_channel.clone()), msg).await;
242
243 if response.error.is_some() {
244 h3_send(
245 &mut stream,
246 StatusCode::OK,
247 &serde_json::to_vec(&response).unwrap(),
248 )
249 .await?;
250 return Ok(());
251 }
252
253 let sub_id = response
254 .result
255 .as_ref()
256 .and_then(|v| serde_json::from_value::<SubscriptionID>(v.clone()).ok())
257 .ok_or_else(|| Error::InvalidMsg("Missing subscription id".into()))?;
258
259 sub_senders.write().await.insert(sub_id, sub_tx);
260
261 let resp = Response::builder()
262 .status(StatusCode::OK)
263 .header("Content-Type", "application/json")
264 .body(())
265 .unwrap();
266 stream.send_response(resp).await.map_err(h3_err)?;
267
268 debug!("--> {response}");
269 let json = serde_json::to_vec(&response).unwrap();
270 stream.send_data(Bytes::from(json)).await.map_err(h3_err)?;
271
272 let encoder = server.config.notification_encoder;
273
274 while let Ok(nt) = sub_rx.recv().await {
275 let notification = encoder(nt);
276 debug!("--> {notification}");
277 let json = serde_json::to_vec(&serde_json::json!(notification)).unwrap();
278 if stream.send_data(Bytes::from(json)).await.is_err() {
279 break;
280 }
281 }
282
283 sub_senders.write().await.remove(&sub_id);
284 sub_channel.close();
285 let _ = stream.finish().await;
286 Ok(())
287}
288
289async fn h3_send(stream: &mut H3Stream, status: StatusCode, body: &[u8]) -> Result<()> {
290 let resp = Response::builder()
291 .status(status)
292 .header("Content-Type", "application/json")
293 .body(())
294 .map_err(|e| Error::HttpError(e.to_string()))?;
295 stream.send_response(resp).await.map_err(h3_err)?;
296 stream
297 .send_data(Bytes::from(body.to_vec()))
298 .await
299 .map_err(h3_err)?;
300 stream.finish().await.map_err(h3_err)?;
301 Ok(())
302}
303
304fn h3_err(e: impl std::fmt::Display) -> Error {
305 Error::HttpError(e.to_string())
306}