Skip to main content

karyon_jsonrpc/server/http/
h3.rs

1//! HTTP/3 over QUIC via h3 / h3-quinn.
2
3use std::{collections::HashMap, sync::Arc};
4
5use bytes::{Buf, Bytes};
6use h3_quinn::Connection;
7use hyper::{Method, Request, Response, StatusCode};
8use log::{debug, error};
9
10use karyon_core::{async_runtime::lock::RwLock, async_util::TaskResult};
11
12use karyon_net::quic::{QuicConn, QuicEndpoint};
13
14use crate::{
15    error::{Error, Result},
16    message::{self, SubscriptionID},
17    server::{
18        channel::{Channel, NewNotification},
19        http::{ERR_BODY_TOO_LARGE, ERR_METHOD_NOT_ALLOWED, MAX_HTTP_BODY_SIZE},
20        Server, CHANNEL_SUBSCRIPTION_BUFFER_SIZE, FAILED_TO_PARSE_ERROR_MSG,
21    },
22};
23
24type H3Stream = h3::server::RequestStream<h3_quinn::BidiStream<Bytes>, Bytes>;
25
26/// Map from `SubscriptionID` to the channel sender that feeds that
27/// subscription's HTTP/3 reply stream. Used by `dispatch_subs_task`
28/// to look up the right sender for each incoming `NewNotification`
29/// and forward the notification to its stream.
30///
31/// Inserted on subscribe, removed on unsubscribe (or when the stream
32/// ends). `RwLock` because the dispatcher reads it on every
33/// notification (frequent), while subscribe / unsubscribe write
34/// rarely.
35type SubSenders = Arc<RwLock<HashMap<SubscriptionID, async_channel::Sender<NewNotification>>>>;
36
37pub(super) async fn accept_h3(server: &Arc<Server>, quic_ep: &QuicEndpoint) {
38    let quic_conn = match quic_ep.accept().await {
39        Ok(conn) => conn,
40        Err(err) => {
41            error!("Accept QUIC for HTTP/3: {err}");
42            return;
43        }
44    };
45
46    server.task_group.spawn(
47        serve_conn_task(server.clone(), quic_conn),
48        |_: TaskResult<Result<()>>| async {},
49    );
50}
51
52async fn serve_conn_task(server: Arc<Server>, quic_conn: QuicConn) -> Result<()> {
53    let peer = quic_conn
54        .peer_endpoint()
55        .map(|e| e.to_string())
56        .unwrap_or_default();
57    if let Err(err) = serve_conn(server, quic_conn).await {
58        debug!("HTTP/3 from {peer} closed: {err}");
59    }
60    Ok(())
61}
62
63async fn serve_conn(server: Arc<Server>, quic_conn: QuicConn) -> Result<()> {
64    let h3_conn = Connection::new(quic_conn.inner().clone());
65    let mut h3_server: h3::server::Connection<Connection, Bytes> =
66        h3::server::Connection::new(h3_conn)
67            .await
68            .map_err(|e| Error::HttpError(format!("H3 handshake: {e}")))?;
69
70    // Per-connection channel feeds all pubsub output.
71    let (ch_tx, ch_rx) = async_channel::bounded(CHANNEL_SUBSCRIPTION_BUFFER_SIZE);
72    let channel = Channel::new(ch_tx);
73
74    let sub_senders: SubSenders = Arc::new(RwLock::new(HashMap::new()));
75
76    // Dispatcher task: route notifications from the connection-wide
77    // channel to the right per-subscription sender.
78    server.task_group.spawn(
79        dispatch_subs_task(ch_rx, sub_senders.clone()),
80        |_: TaskResult<Result<()>>| async {},
81    );
82
83    loop {
84        match h3_server.accept().await {
85            Ok(Some(resolver)) => {
86                let (req, stream) = match resolver.resolve_request().await {
87                    Ok(v) => v,
88                    Err(err) => {
89                        error!("Resolve HTTP/3 request: {err}");
90                        continue;
91                    }
92                };
93                server.task_group.spawn(
94                    handle_request_task(
95                        server.clone(),
96                        channel.clone(),
97                        sub_senders.clone(),
98                        req,
99                        stream,
100                    ),
101                    |_: TaskResult<Result<()>>| async {},
102                );
103            }
104            Ok(None) => break,
105            Err(err) => {
106                debug!("HTTP/3 connection error: {err}");
107                break;
108            }
109        }
110    }
111
112    for (_, sender) in sub_senders.write().await.drain() {
113        sender.close();
114    }
115    channel.close();
116    Ok(())
117}
118
119async fn dispatch_subs_task(
120    ch_rx: async_channel::Receiver<NewNotification>,
121    sub_senders: SubSenders,
122) -> Result<()> {
123    while let Ok(nt) = ch_rx.recv().await {
124        let subs = sub_senders.read().await;
125        if let Some(sender) = subs.get(&nt.sub_id) {
126            let _ = sender.send(nt).await;
127        }
128    }
129    Ok(())
130}
131
132async fn handle_request_task(
133    server: Arc<Server>,
134    channel: Arc<Channel>,
135    sub_senders: SubSenders,
136    req: Request<()>,
137    stream: H3Stream,
138) -> Result<()> {
139    if let Err(err) = handle_h3_request(server, channel, sub_senders, req, stream).await {
140        error!("Handle HTTP/3 request: {err}");
141    }
142    Ok(())
143}
144
145async fn handle_h3_request(
146    server: Arc<Server>,
147    channel: Arc<Channel>,
148    sub_senders: SubSenders,
149    req: Request<()>,
150    mut stream: H3Stream,
151) -> Result<()> {
152    if req.method() != Method::POST {
153        h3_send(
154            &mut stream,
155            StatusCode::METHOD_NOT_ALLOWED,
156            ERR_METHOD_NOT_ALLOWED.as_bytes(),
157        )
158        .await?;
159        return Ok(());
160    }
161
162    let mut body = Vec::new();
163    while let Some(chunk) = stream.recv_data().await.map_err(h3_err)? {
164        body.extend_from_slice(Buf::chunk(&chunk));
165        if body.len() as u64 > MAX_HTTP_BODY_SIZE {
166            h3_send(
167                &mut stream,
168                StatusCode::PAYLOAD_TOO_LARGE,
169                ERR_BODY_TOO_LARGE.as_bytes(),
170            )
171            .await?;
172            return Ok(());
173        }
174    }
175
176    let msg: serde_json::Value = match serde_json::from_slice(&body) {
177        Ok(v) => v,
178        Err(_) => {
179            let resp = message::Response {
180                error: Some(message::Error {
181                    code: message::PARSE_ERROR_CODE,
182                    message: FAILED_TO_PARSE_ERROR_MSG.to_string(),
183                    data: None,
184                }),
185                ..Default::default()
186            };
187            h3_send(
188                &mut stream,
189                StatusCode::OK,
190                &serde_json::to_vec(&resp).unwrap(),
191            )
192            .await?;
193            return Ok(());
194        }
195    };
196
197    let is_subscribe = msg
198        .get("method")
199        .and_then(|m| m.as_str())
200        .map(|m| m.ends_with("_subscribe") && !m.ends_with("_unsubscribe"))
201        .unwrap_or(false);
202
203    if is_subscribe {
204        return handle_h3_subscribe(server, sub_senders, msg, stream).await;
205    }
206
207    let response = server.handle_request(Some(channel.clone()), msg).await;
208    debug!("--> {response}");
209
210    if response.error.is_none() {
211        if let Ok(rpc_req) = serde_json::from_slice::<message::Request>(&body) {
212            if let Some(params) = &rpc_req.params {
213                if let Ok(sub_id) = serde_json::from_value::<SubscriptionID>(params.clone()) {
214                    if let Some(sender) = sub_senders.write().await.remove(&sub_id) {
215                        sender.close();
216                    }
217                }
218            }
219        }
220    }
221
222    h3_send(
223        &mut stream,
224        StatusCode::OK,
225        &serde_json::to_vec(&response).unwrap(),
226    )
227    .await?;
228    Ok(())
229}
230
231/// Subscribe: respond, then stream notifications on the same stream.
232async fn handle_h3_subscribe(
233    server: Arc<Server>,
234    sub_senders: SubSenders,
235    msg: serde_json::Value,
236    mut stream: H3Stream,
237) -> Result<()> {
238    let (sub_tx, sub_rx) = async_channel::bounded(CHANNEL_SUBSCRIPTION_BUFFER_SIZE);
239    let sub_channel = Channel::new(sub_tx.clone());
240
241    let response = server.handle_request(Some(sub_channel.clone()), msg).await;
242
243    if response.error.is_some() {
244        h3_send(
245            &mut stream,
246            StatusCode::OK,
247            &serde_json::to_vec(&response).unwrap(),
248        )
249        .await?;
250        return Ok(());
251    }
252
253    let sub_id = response
254        .result
255        .as_ref()
256        .and_then(|v| serde_json::from_value::<SubscriptionID>(v.clone()).ok())
257        .ok_or_else(|| Error::InvalidMsg("Missing subscription id".into()))?;
258
259    sub_senders.write().await.insert(sub_id, sub_tx);
260
261    let resp = Response::builder()
262        .status(StatusCode::OK)
263        .header("Content-Type", "application/json")
264        .body(())
265        .unwrap();
266    stream.send_response(resp).await.map_err(h3_err)?;
267
268    debug!("--> {response}");
269    let json = serde_json::to_vec(&response).unwrap();
270    stream.send_data(Bytes::from(json)).await.map_err(h3_err)?;
271
272    let encoder = server.config.notification_encoder;
273
274    while let Ok(nt) = sub_rx.recv().await {
275        let notification = encoder(nt);
276        debug!("--> {notification}");
277        let json = serde_json::to_vec(&serde_json::json!(notification)).unwrap();
278        if stream.send_data(Bytes::from(json)).await.is_err() {
279            break;
280        }
281    }
282
283    sub_senders.write().await.remove(&sub_id);
284    sub_channel.close();
285    let _ = stream.finish().await;
286    Ok(())
287}
288
289async fn h3_send(stream: &mut H3Stream, status: StatusCode, body: &[u8]) -> Result<()> {
290    let resp = Response::builder()
291        .status(status)
292        .header("Content-Type", "application/json")
293        .body(())
294        .map_err(|e| Error::HttpError(e.to_string()))?;
295    stream.send_response(resp).await.map_err(h3_err)?;
296    stream
297        .send_data(Bytes::from(body.to_vec()))
298        .await
299        .map_err(h3_err)?;
300    stream.finish().await.map_err(h3_err)?;
301    Ok(())
302}
303
304fn h3_err(e: impl std::fmt::Display) -> Error {
305    Error::HttpError(e.to_string())
306}