1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use async_channel::Sender;
5use async_trait::async_trait;
6use log::{debug, error};
7
8use karyon_core::async_util::{AsyncQueue, TaskGroup, TaskResult};
9use karyon_net::{FramedReader, FramedWriter};
10
11#[cfg(feature = "quic")]
12use karyon_core::async_runtime::io::{AsyncReadExt, AsyncWriteExt};
13#[cfg(feature = "quic")]
14use karyon_net::{framed, quic::QuicConn, StreamMux};
15
16#[cfg(feature = "quic")]
17use crate::message::StreamInit;
18
19use crate::{
20 codec::PeerNetMsgCodec,
21 conn_queue::QueuedConn,
22 message::{PeerNetCmd, PeerNetMsg, ProtocolMsg, ShutdownMsg},
23 peer::ConnDirection,
24 protocol::{ProtocolEvent, ProtocolID},
25 util::{decode, encode},
26 Error, Result,
27};
28
29const SEND_QUEUE_SIZE: usize = 128;
30const RECV_QUEUE_SIZE: usize = 128;
31
32#[async_trait]
35pub(crate) trait PeerConnection: Send + Sync {
36 async fn send(&self, proto_id: &ProtocolID, payload: Vec<u8>) -> Result<()>;
38 async fn recv(&self, proto_id: &ProtocolID) -> Result<ProtocolEvent>;
41 async fn shutdown(&self) -> Result<()>;
44}
45
46pub(crate) async fn from_queued(
49 queued: QueuedConn,
50 negotiated: &HashSet<ProtocolID>,
51 proto_ids: impl IntoIterator<Item = ProtocolID> + Clone,
52 task_group: &TaskGroup,
53 stop_chan: Sender<Result<()>>,
54) -> Result<Arc<dyn PeerConnection>> {
55 #[cfg(feature = "quic")]
56 if queued.quic_conn.is_some() {
57 let conn = MuxConnection::from_queued(queued, negotiated, proto_ids, task_group).await?;
58 return Ok(Arc::new(conn) as Arc<dyn PeerConnection>);
59 }
60
61 let _ = negotiated;
62 let conn = SingleConnection::from_queued(queued, proto_ids, task_group, stop_chan);
63 Ok(Arc::new(conn))
64}
65
66pub(crate) struct SingleConnection {
69 send_queue: Arc<AsyncQueue<PeerNetMsg>>,
70 recv_queues: HashMap<ProtocolID, Arc<AsyncQueue<ProtocolEvent>>>,
71}
72
73impl SingleConnection {
74 pub(crate) fn from_queued(
76 queued: QueuedConn,
77 proto_ids: impl IntoIterator<Item = ProtocolID>,
78 task_group: &TaskGroup,
79 stop_chan: Sender<Result<()>>,
80 ) -> Self {
81 let send_queue = AsyncQueue::new(SEND_QUEUE_SIZE);
82 let recv_queues = build_recv_queues(proto_ids);
83
84 spawn_writer_task(task_group, queued.writer, send_queue.clone());
85 spawn_demux_reader(task_group, queued.reader, recv_queues.clone(), stop_chan);
86
87 Self {
88 send_queue,
89 recv_queues,
90 }
91 }
92}
93
94#[async_trait]
95impl PeerConnection for SingleConnection {
96 async fn send(&self, proto_id: &ProtocolID, payload: Vec<u8>) -> Result<()> {
97 let proto_msg = ProtocolMsg {
98 protocol_id: proto_id.clone(),
99 payload,
100 };
101 let net_msg = PeerNetMsg::new(PeerNetCmd::Protocol, &proto_msg)?;
102 self.send_queue.push(net_msg).await;
103 Ok(())
104 }
105
106 async fn recv(&self, proto_id: &ProtocolID) -> Result<ProtocolEvent> {
107 match self.recv_queues.get(proto_id) {
108 Some(q) => Ok(q.recv().await),
109 None => Err(Error::UnsupportedProtocol(proto_id.clone())),
110 }
111 }
112
113 async fn shutdown(&self) -> Result<()> {
114 let m = PeerNetMsg::new(PeerNetCmd::Shutdown, ShutdownMsg(0))?;
115 self.send_queue.push(m).await;
116 broadcast_shutdown(&self.recv_queues).await;
117 Ok(())
118 }
119}
120
/// A peer connection that gives each negotiated protocol its own QUIC stream.
///
/// Every stream has a dedicated send queue, drained by a writer task. Every
/// stream also feeds a per-protocol receive queue, filled by a QUIC reader task.
#[cfg(feature = "quic")]
pub(crate) struct MuxConnection {
    send_queues: HashMap<ProtocolID, Arc<AsyncQueue<PeerNetMsg>>>,
    recv_queues: HashMap<ProtocolID, Arc<AsyncQueue<ProtocolEvent>>>,
    // Kept so `shutdown` can close the underlying QUIC connection.
    quic_conn: QuicConn,
}

#[cfg(feature = "quic")]
impl MuxConnection {
    /// Sets up one QUIC stream per protocol in `negotiated` and spawns a writer
    /// task and a reader task for each stream on `task_group`.
    ///
    /// Receive queues are created for every ID in `proto_ids`.
    ///
    /// # Errors
    /// * `Error::InvalidMsg` if `queued` carries no QUIC connection.
    /// * Any error from `setup_quic_streams`.
    /// * `Error::UnsupportedProtocol` if a stream was set up for a protocol with
    ///   no receive queue (present in `negotiated` but not in `proto_ids`). All
    ///   streams are checked first, so no task is spawned in this case.
    pub(crate) async fn from_queued(
        queued: QueuedConn,
        negotiated: &HashSet<ProtocolID>,
        proto_ids: impl IntoIterator<Item = ProtocolID>,
        task_group: &TaskGroup,
    ) -> Result<Self> {
        let quic_conn = queued
            .quic_conn
            .ok_or_else(|| Error::InvalidMsg("MuxConnection requires a QUIC conn".into()))?;

        let recv_queues = build_recv_queues(proto_ids);
        let (send_queues, streams) =
            setup_quic_streams(&quic_conn, &queued.direction, negotiated).await?;

        // Resolve the queues of every stream before spawning anything, so a
        // missing receive queue neither panics nor leaves some streams wired
        // and others not.
        let mut wired = Vec::with_capacity(streams.len());
        for stream in streams {
            let q_send = send_queues
                .get(&stream.proto_id)
                .expect("send queue installed in setup_quic_streams")
                .clone();
            let q_recv = match recv_queues.get(&stream.proto_id) {
                Some(q) => q.clone(),
                None => return Err(Error::UnsupportedProtocol(stream.proto_id)),
            };
            wired.push((stream, q_send, q_recv));
        }

        for (stream, q_send, q_recv) in wired {
            spawn_writer_task(task_group, stream.writer, q_send);
            spawn_quic_reader(task_group, stream.proto_id, stream.reader, q_recv);
        }

        Ok(Self {
            send_queues,
            recv_queues,
            quic_conn,
        })
    }
}
168
#[cfg(feature = "quic")]
#[async_trait]
impl PeerConnection for MuxConnection {
    /// Wraps `payload` in a `Protocol` message and queues it on the stream
    /// dedicated to `proto_id`.
    /// Fails with `UnsupportedProtocol` if no stream exists for that protocol.
    async fn send(&self, proto_id: &ProtocolID, payload: Vec<u8>) -> Result<()> {
        let outgoing = PeerNetMsg::new(
            PeerNetCmd::Protocol,
            &ProtocolMsg {
                protocol_id: proto_id.clone(),
                payload,
            },
        )?;
        let stream_queue = self
            .send_queues
            .get(proto_id)
            .ok_or_else(|| Error::UnsupportedProtocol(proto_id.clone()))?;
        stream_queue.push(outgoing).await;
        Ok(())
    }

    /// Waits on the receive queue of `proto_id`.
    /// Fails with `UnsupportedProtocol` if there is no such queue.
    async fn recv(&self, proto_id: &ProtocolID) -> Result<ProtocolEvent> {
        let queue = self
            .recv_queues
            .get(proto_id)
            .ok_or_else(|| Error::UnsupportedProtocol(proto_id.clone()))?;
        Ok(queue.recv().await)
    }

    /// Closes the QUIC connection (code 0, reason `shutdown`), then pushes a
    /// shutdown event to every local receive queue.
    async fn shutdown(&self) -> Result<()> {
        self.quic_conn.close(0, b"shutdown");
        broadcast_shutdown(&self.recv_queues).await;
        Ok(())
    }
}
200
/// One protocol's QUIC stream, split into framed halves, as produced by
/// `setup_quic_streams`.
#[cfg(feature = "quic")]
struct MuxStream {
    /// The protocol named in this stream's `StreamInit` header.
    proto_id: ProtocolID,
    /// Incoming half; handed to `spawn_quic_reader`.
    reader: FramedReader<PeerNetMsgCodec>,
    /// Outgoing half; handed to `spawn_writer_task`.
    writer: FramedWriter<PeerNetMsgCodec>,
}
208
209fn spawn_writer_task(
212 task_group: &TaskGroup,
213 mut writer: FramedWriter<PeerNetMsgCodec>,
214 queue: Arc<AsyncQueue<PeerNetMsg>>,
215) {
216 task_group.spawn(
217 async move {
218 loop {
219 let msg = queue.recv().await;
220 if writer.send_msg(msg).await.is_err() {
221 break;
222 }
223 }
224 Ok::<(), Error>(())
225 },
226 |res: TaskResult<Result<()>>| async move {
227 debug!("Peer writer task ended: {res}");
228 },
229 );
230}
231
232fn spawn_demux_reader(
236 task_group: &TaskGroup,
237 mut reader: FramedReader<PeerNetMsgCodec>,
238 recv_queues: HashMap<ProtocolID, Arc<AsyncQueue<ProtocolEvent>>>,
239 stop_chan: Sender<Result<()>>,
240) {
241 task_group.spawn(
242 async move {
243 loop {
244 let msg = match reader.recv_msg().await {
245 Ok(m) => m,
246 Err(e) => {
247 let _ = stop_chan.try_send(Err(e.into()));
248 break;
249 }
250 };
251 match msg.header.command {
252 PeerNetCmd::Protocol => {
253 let proto_msg: ProtocolMsg = match decode(&msg.payload) {
254 Ok((m, _)) => m,
255 Err(e) => {
256 let _ = stop_chan.try_send(Err(e));
257 break;
258 }
259 };
260 match recv_queues.get(&proto_msg.protocol_id) {
261 Some(q) => {
262 q.push(ProtocolEvent::Message(proto_msg.payload)).await;
263 }
264 None => {
265 error!("No recv queue for protocol {}", proto_msg.protocol_id);
266 }
267 }
268 }
269 PeerNetCmd::Shutdown => {
270 let _ = stop_chan.try_send(Err(Error::PeerShutdown));
271 break;
272 }
273 command => {
274 let _ = stop_chan.try_send(Err(Error::InvalidMsg(format!(
275 "Unexpected msg {command:?}"
276 ))));
277 break;
278 }
279 }
280 }
281 Ok::<(), Error>(())
282 },
283 |res: TaskResult<Result<()>>| async move {
284 debug!("Peer reader task ended: {res}");
285 },
286 );
287}
288
/// Spawns a task that reads framed messages from the QUIC stream bound to
/// `proto_id` and forwards `Protocol` payloads into `recv_queue`.
///
/// The task ends in three cases:
/// * The stream delivers a `Shutdown` message.
/// * A read fails. The error becomes the task result, which the completion
///   callback logs.
/// * A payload fails to decode.
///
/// A `Protocol` message whose embedded protocol ID differs from `proto_id` is
/// logged and dropped rather than delivered to the wrong protocol. Other
/// unexpected commands are logged and skipped.
///
/// NOTE(review): unlike `spawn_demux_reader`, this task has no stop channel,
/// so a failed stream is not reported to the peer and consumers blocked on
/// `recv_queue` are not woken here — confirm this is handled elsewhere.
#[cfg(feature = "quic")]
fn spawn_quic_reader(
    task_group: &TaskGroup,
    proto_id: ProtocolID,
    mut reader: FramedReader<PeerNetMsgCodec>,
    recv_queue: Arc<AsyncQueue<ProtocolEvent>>,
) {
    task_group.spawn(
        async move {
            loop {
                // Surface read errors in the task result instead of discarding them.
                let msg = reader.recv_msg().await?;
                match msg.header.command {
                    PeerNetCmd::Protocol => {
                        let proto_msg: ProtocolMsg = decode(&msg.payload)?.0;
                        // Each stream is bound to a single protocol by its StreamInit
                        // header; never deliver another protocol's payload here.
                        if proto_msg.protocol_id == proto_id {
                            recv_queue
                                .push(ProtocolEvent::Message(proto_msg.payload))
                                .await;
                        } else {
                            error!(
                                "Dropping msg for protocol {} received on QUIC stream {proto_id}",
                                proto_msg.protocol_id
                            );
                        }
                    }
                    PeerNetCmd::Shutdown => break,
                    cmd => {
                        error!("Unexpected msg on QUIC stream {proto_id}: {cmd:?}");
                    }
                }
            }
            Ok::<(), Error>(())
        },
        |res: TaskResult<Result<()>>| async move {
            debug!("QUIC stream reader task ended: {res}");
        },
    );
}
325
/// Opens (outbound) or accepts (inbound) one QUIC stream per negotiated
/// protocol and wraps each stream in a `PeerNetMsgCodec` framed reader/writer.
///
/// Outbound: a stream is opened for every ID in `negotiated`. A `StreamInit`
/// header naming the protocol is written and flushed before the stream is
/// framed.
///
/// Inbound: streams are accepted until one has been bound to each negotiated
/// protocol. A stream is dropped, and does not count towards that total, if it
/// does any of the following:
/// * closes before sending a header;
/// * names a protocol outside `negotiated`;
/// * repeats an already-bound protocol.
///
/// The header is read without over-reading, so the first frame that follows
/// it reaches the framed reader intact.
///
/// Returns a send queue per bound protocol together with the framed streams.
///
/// # Errors
/// Propagates failures from opening or accepting streams, from header I/O,
/// and from encoding or decoding `StreamInit`.
///
/// NOTE(review): the inbound loop has no timeout. A peer that opens fewer
/// streams than negotiated keeps it waiting in `accept_stream`.
#[cfg(feature = "quic")]
async fn setup_quic_streams(
    quic_conn: &QuicConn,
    direction: &ConnDirection,
    negotiated: &HashSet<ProtocolID>,
) -> Result<(
    HashMap<ProtocolID, Arc<AsyncQueue<PeerNetMsg>>>,
    Vec<MuxStream>,
)> {
    let mut queues = HashMap::new();
    let mut streams = Vec::new();

    match direction {
        ConnDirection::Outbound => {
            for proto_id in negotiated.iter() {
                let mut stream = quic_conn.open_stream().await?;

                let init = StreamInit {
                    protocol_id: proto_id.clone(),
                };
                let encoded = encode(&init)?;
                stream.write_all(&encoded).await?;
                stream.flush().await?;

                let conn = framed(stream, PeerNetMsgCodec::new());
                let (reader, writer) = conn.split();

                queues.insert(proto_id.clone(), AsyncQueue::new(SEND_QUEUE_SIZE));
                streams.push(MuxStream {
                    proto_id: proto_id.clone(),
                    reader,
                    writer,
                });
            }
        }
        ConnDirection::Inbound => {
            let expected = negotiated.len();

            while streams.len() < expected {
                let mut stream = quic_conn.accept_stream().await?;

                // A stream that ends before sending any header byte is skipped.
                let init = match read_stream_init(&mut stream).await? {
                    Some(init) => init,
                    None => continue,
                };

                if !negotiated.contains(&init.protocol_id) {
                    error!("Unsupported protocol: {}", init.protocol_id);
                    continue;
                }
                // A second stream for the same protocol would replace the first
                // one's queue and leave another negotiated protocol unbound.
                if queues.contains_key(&init.protocol_id) {
                    error!("Duplicate stream for protocol: {}", init.protocol_id);
                    continue;
                }

                let conn = framed(stream, PeerNetMsgCodec::new());
                let (reader, writer) = conn.split();

                queues.insert(init.protocol_id.clone(), AsyncQueue::new(SEND_QUEUE_SIZE));
                streams.push(MuxStream {
                    proto_id: init.protocol_id,
                    reader,
                    writer,
                });
            }
        }
    }

    Ok((queues, streams))
}

/// Reads one `StreamInit` header from the start of an inbound QUIC stream
/// without consuming any bytes that follow it.
///
/// Bytes are read one at a time, and a decode is attempted after each byte. The
/// first successful decode therefore ends exactly at the header boundary, and
/// the rest of the stream is left for the framed reader. This assumes the
/// `StreamInit` encoding is self-delimiting, i.e. a strict prefix fails to
/// decode.
///
/// Returns `Ok(None)` if the stream ends before any byte arrives.
///
/// # Errors
/// * I/O errors from the stream.
/// * `Error::InvalidMsg` if the stream ends mid-header.
/// * The last decode error if no valid header is found within
///   `MAX_STREAM_INIT_LEN` bytes.
#[cfg(feature = "quic")]
async fn read_stream_init<S>(stream: &mut S) -> Result<Option<StreamInit>>
where
    S: AsyncReadExt + Unpin,
{
    // Upper bound on the encoded header size, so a peer cannot make us read
    // an unbounded amount while looking for it.
    const MAX_STREAM_INIT_LEN: usize = 256;

    let mut header = Vec::new();
    let mut byte = [0u8; 1];
    loop {
        if stream.read(&mut byte).await? == 0 {
            if header.is_empty() {
                return Ok(None);
            }
            return Err(Error::InvalidMsg("Truncated QUIC stream init header".into()));
        }
        header.push(byte[0]);

        let decoded: Result<(StreamInit, _)> = decode(&header);
        match decoded {
            Ok((init, _)) => return Ok(Some(init)),
            Err(err) if header.len() >= MAX_STREAM_INIT_LEN => return Err(err),
            // Most likely an incomplete header; read another byte.
            Err(_) => {}
        }
    }
}
402
403fn build_recv_queues(
405 proto_ids: impl IntoIterator<Item = ProtocolID>,
406) -> HashMap<ProtocolID, Arc<AsyncQueue<ProtocolEvent>>> {
407 proto_ids
408 .into_iter()
409 .map(|id| (id, AsyncQueue::new(RECV_QUEUE_SIZE)))
410 .collect()
411}
412
413async fn broadcast_shutdown(queues: &HashMap<ProtocolID, Arc<AsyncQueue<ProtocolEvent>>>) {
416 for q in queues.values() {
417 q.push(ProtocolEvent::Shutdown).await;
418 }
419}