Skip to main content

karyon_p2p/peer/
connection.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use async_channel::Sender;
5use async_trait::async_trait;
6use log::{debug, error};
7
8use karyon_core::async_util::{AsyncQueue, TaskGroup, TaskResult};
9use karyon_net::{FramedReader, FramedWriter};
10
11#[cfg(feature = "quic")]
12use karyon_core::async_runtime::io::{AsyncReadExt, AsyncWriteExt};
13#[cfg(feature = "quic")]
14use karyon_net::{framed, quic::QuicConn, StreamMux};
15
16#[cfg(feature = "quic")]
17use crate::message::StreamInit;
18
19use crate::{
20    codec::PeerNetMsgCodec,
21    conn_queue::QueuedConn,
22    message::{PeerNetCmd, PeerNetMsg, ProtocolMsg, ShutdownMsg},
23    peer::ConnDirection,
24    protocol::{ProtocolEvent, ProtocolID},
25    util::{decode, encode},
26    Error, Result,
27};
28
/// Capacity of each outgoing message queue: the single shared queue on the
/// TCP/TLS path, or each per-stream queue on the QUIC path.
const SEND_QUEUE_SIZE: usize = 128;
/// Capacity of each per-protocol incoming event queue.
const RECV_QUEUE_SIZE: usize = 128;
31
32/// Per-peer wire abstraction. Hides single-pipe (TCP/TLS) vs.
33/// stream-mux (QUIC) framing from the layers above.
34#[async_trait]
35pub(crate) trait PeerConnection: Send + Sync {
36    /// Send pre-encoded payload bytes for `proto_id`.
37    async fn send(&self, proto_id: &ProtocolID, payload: Vec<u8>) -> Result<()>;
38    /// Pop the next event for `proto_id`. Blocks until a message
39    /// arrives or shutdown is broadcast.
40    async fn recv(&self, proto_id: &ProtocolID) -> Result<ProtocolEvent>;
41    /// Graceful close. Pushes Shutdown to every recv queue and signals
42    /// the wire (transport-specific).
43    async fn shutdown(&self) -> Result<()>;
44}
45
46/// Build the right `PeerConnection` for the post-handshake `QueuedConn`.
47/// Picks `MuxConnection` when QUIC is in use, `SingleConnection` otherwise.
48pub(crate) async fn from_queued(
49    queued: QueuedConn,
50    negotiated: &HashSet<ProtocolID>,
51    proto_ids: impl IntoIterator<Item = ProtocolID> + Clone,
52    task_group: &TaskGroup,
53    stop_chan: Sender<Result<()>>,
54) -> Result<Arc<dyn PeerConnection>> {
55    #[cfg(feature = "quic")]
56    if queued.quic_conn.is_some() {
57        let conn = MuxConnection::from_queued(queued, negotiated, proto_ids, task_group).await?;
58        return Ok(Arc::new(conn) as Arc<dyn PeerConnection>);
59    }
60
61    let _ = negotiated;
62    let conn = SingleConnection::from_queued(queued, proto_ids, task_group, stop_chan);
63    Ok(Arc::new(conn))
64}
65
66/// TCP / TLS path: one shared writer drains `send_queue`; a single
67/// reader demuxes incoming `PeerNetMsg`s into the matching `recv_queues`.
68pub(crate) struct SingleConnection {
69    send_queue: Arc<AsyncQueue<PeerNetMsg>>,
70    recv_queues: HashMap<ProtocolID, Arc<AsyncQueue<ProtocolEvent>>>,
71}
72
73impl SingleConnection {
74    /// Spawn the writer + demux reader and return the connection.
75    pub(crate) fn from_queued(
76        queued: QueuedConn,
77        proto_ids: impl IntoIterator<Item = ProtocolID>,
78        task_group: &TaskGroup,
79        stop_chan: Sender<Result<()>>,
80    ) -> Self {
81        let send_queue = AsyncQueue::new(SEND_QUEUE_SIZE);
82        let recv_queues = build_recv_queues(proto_ids);
83
84        spawn_writer_task(task_group, queued.writer, send_queue.clone());
85        spawn_demux_reader(task_group, queued.reader, recv_queues.clone(), stop_chan);
86
87        Self {
88            send_queue,
89            recv_queues,
90        }
91    }
92}
93
94#[async_trait]
95impl PeerConnection for SingleConnection {
96    async fn send(&self, proto_id: &ProtocolID, payload: Vec<u8>) -> Result<()> {
97        let proto_msg = ProtocolMsg {
98            protocol_id: proto_id.clone(),
99            payload,
100        };
101        let net_msg = PeerNetMsg::new(PeerNetCmd::Protocol, &proto_msg)?;
102        self.send_queue.push(net_msg).await;
103        Ok(())
104    }
105
106    async fn recv(&self, proto_id: &ProtocolID) -> Result<ProtocolEvent> {
107        match self.recv_queues.get(proto_id) {
108            Some(q) => Ok(q.recv().await),
109            None => Err(Error::UnsupportedProtocol(proto_id.clone())),
110        }
111    }
112
113    async fn shutdown(&self) -> Result<()> {
114        let m = PeerNetMsg::new(PeerNetCmd::Shutdown, ShutdownMsg(0))?;
115        self.send_queue.push(m).await;
116        broadcast_shutdown(&self.recv_queues).await;
117        Ok(())
118    }
119}
120
/// QUIC transport: every protocol runs on its own stream with a dedicated
/// writer task and reader task, so protocols never share a send queue.
#[cfg(feature = "quic")]
pub(crate) struct MuxConnection {
    /// Outgoing queues, one per protocol that has a stream.
    send_queues: HashMap<ProtocolID, Arc<AsyncQueue<PeerNetMsg>>>,
    /// Incoming event queues, one per protocol id given at construction.
    recv_queues: HashMap<ProtocolID, Arc<AsyncQueue<ProtocolEvent>>>,
    /// The underlying QUIC connection; closed by `shutdown`.
    quic_conn: QuicConn,
}
129
#[cfg(feature = "quic")]
impl MuxConnection {
    /// Open (outbound) or accept (inbound) one QUIC stream per negotiated
    /// protocol, then spawn a writer task and a reader task for each stream.
    ///
    /// `proto_ids` decides which protocols get a recv queue. Every protocol
    /// that ends up with a stream must be among them.
    ///
    /// # Errors
    /// * `Error::InvalidMsg` if `queued` carries no QUIC connection.
    /// * `Error::UnsupportedProtocol` if a stream was set up for a protocol
    ///   that has no recv queue (i.e. it is missing from `proto_ids`). This
    ///   used to be a panic.
    /// * Any error from `setup_quic_streams`.
    pub(crate) async fn from_queued(
        queued: QueuedConn,
        negotiated: &HashSet<ProtocolID>,
        proto_ids: impl IntoIterator<Item = ProtocolID>,
        task_group: &TaskGroup,
    ) -> Result<Self> {
        let quic_conn = queued
            .quic_conn
            .ok_or_else(|| Error::InvalidMsg("MuxConnection requires a QUIC conn".into()))?;

        let recv_queues = build_recv_queues(proto_ids);
        let (send_queues, streams) =
            setup_quic_streams(&quic_conn, &queued.direction, negotiated).await?;

        // Resolve every stream's queues before spawning anything. A missing
        // recv queue then rejects the whole connection, instead of panicking
        // or leaving tasks running for a connection that is being dropped.
        let mut wired = Vec::with_capacity(streams.len());
        for stream in streams {
            let q_send = send_queues
                .get(&stream.proto_id)
                .expect("send queue installed in setup_quic_streams")
                .clone();
            let q_recv = recv_queues
                .get(&stream.proto_id)
                .cloned()
                .ok_or_else(|| Error::UnsupportedProtocol(stream.proto_id.clone()))?;
            wired.push((stream, q_send, q_recv));
        }

        for (stream, q_send, q_recv) in wired {
            spawn_writer_task(task_group, stream.writer, q_send);
            spawn_quic_reader(task_group, stream.proto_id, stream.reader, q_recv);
        }

        Ok(Self {
            send_queues,
            recv_queues,
            quic_conn,
        })
    }
}
168
#[cfg(feature = "quic")]
#[async_trait]
impl PeerConnection for MuxConnection {
    /// Encode `payload` as a `Protocol` net message and enqueue it on the
    /// stream for `proto_id`. Encoding runs before the stream lookup, so an
    /// encoding failure is reported ahead of `UnsupportedProtocol`.
    async fn send(&self, proto_id: &ProtocolID, payload: Vec<u8>) -> Result<()> {
        let wrapped = ProtocolMsg {
            protocol_id: proto_id.clone(),
            payload,
        };
        let frame = PeerNetMsg::new(PeerNetCmd::Protocol, &wrapped)?;
        let stream_queue = self
            .send_queues
            .get(proto_id)
            .ok_or_else(|| Error::UnsupportedProtocol(proto_id.clone()))?;
        stream_queue.push(frame).await;
        Ok(())
    }

    /// Wait on the recv queue for `proto_id`; `UnsupportedProtocol` if there
    /// is no such queue.
    async fn recv(&self, proto_id: &ProtocolID) -> Result<ProtocolEvent> {
        let queue = self
            .recv_queues
            .get(proto_id)
            .ok_or_else(|| Error::UnsupportedProtocol(proto_id.clone()))?;
        Ok(queue.recv().await)
    }

    /// Close the QUIC connection (code 0, reason "shutdown") rather than
    /// sending a `Shutdown` net message, then wake every local recv queue.
    async fn shutdown(&self) -> Result<()> {
        self.quic_conn.close(0, b"shutdown");
        broadcast_shutdown(&self.recv_queues).await;
        Ok(())
    }
}
200
/// The framed halves of one per-protocol QUIC stream, as produced by
/// `setup_quic_streams`.
#[cfg(feature = "quic")]
struct MuxStream {
    /// Protocol this stream is dedicated to.
    proto_id: ProtocolID,
    /// Incoming half; handed to the stream's reader task.
    reader: FramedReader<PeerNetMsgCodec>,
    /// Outgoing half; handed to the stream's writer task.
    writer: FramedWriter<PeerNetMsgCodec>,
}
208
209/// Spawn a task that drains `queue` into `writer`. Exits when the
210/// writer fails (peer hung up).
211fn spawn_writer_task(
212    task_group: &TaskGroup,
213    mut writer: FramedWriter<PeerNetMsgCodec>,
214    queue: Arc<AsyncQueue<PeerNetMsg>>,
215) {
216    task_group.spawn(
217        async move {
218            loop {
219                let msg = queue.recv().await;
220                if writer.send_msg(msg).await.is_err() {
221                    break;
222                }
223            }
224            Ok::<(), Error>(())
225        },
226        |res: TaskResult<Result<()>>| async move {
227            debug!("Peer writer task ended: {res}");
228        },
229    );
230}
231
232/// Spawn the single-pipe reader task. Reads `PeerNetMsg`s, routes
233/// `Protocol` payloads into the matching recv queue, signals the peer
234/// via `stop_chan` on Shutdown / error.
235fn spawn_demux_reader(
236    task_group: &TaskGroup,
237    mut reader: FramedReader<PeerNetMsgCodec>,
238    recv_queues: HashMap<ProtocolID, Arc<AsyncQueue<ProtocolEvent>>>,
239    stop_chan: Sender<Result<()>>,
240) {
241    task_group.spawn(
242        async move {
243            loop {
244                let msg = match reader.recv_msg().await {
245                    Ok(m) => m,
246                    Err(e) => {
247                        let _ = stop_chan.try_send(Err(e.into()));
248                        break;
249                    }
250                };
251                match msg.header.command {
252                    PeerNetCmd::Protocol => {
253                        let proto_msg: ProtocolMsg = match decode(&msg.payload) {
254                            Ok((m, _)) => m,
255                            Err(e) => {
256                                let _ = stop_chan.try_send(Err(e));
257                                break;
258                            }
259                        };
260                        match recv_queues.get(&proto_msg.protocol_id) {
261                            Some(q) => {
262                                q.push(ProtocolEvent::Message(proto_msg.payload)).await;
263                            }
264                            None => {
265                                error!("No recv queue for protocol {}", proto_msg.protocol_id);
266                            }
267                        }
268                    }
269                    PeerNetCmd::Shutdown => {
270                        let _ = stop_chan.try_send(Err(Error::PeerShutdown));
271                        break;
272                    }
273                    command => {
274                        let _ = stop_chan.try_send(Err(Error::InvalidMsg(format!(
275                            "Unexpected msg {command:?}"
276                        ))));
277                        break;
278                    }
279                }
280            }
281            Ok::<(), Error>(())
282        },
283        |res: TaskResult<Result<()>>| async move {
284            debug!("Peer reader task ended: {res}");
285        },
286    );
287}
288
/// Spawn the reader task for one QUIC protocol stream.
///
/// The stream is dedicated to `proto_id`, so each `Protocol` message's
/// payload goes straight into `recv_queue`. A `Shutdown` message or a read
/// error ends the task normally. An undecodable payload ends it with that
/// decode error. Other commands are logged and skipped.
///
/// Whichever way the task ends, it then pushes `ProtocolEvent::Shutdown`
/// into `recv_queue`. This task has no `stop_chan`, so without that push a
/// protocol blocked in `recv` would never learn that the stream is gone.
/// NOTE(review): on a local `shutdown`, `broadcast_shutdown` pushes another
/// `Shutdown` into the same queue. Consumers should treat a repeated
/// `Shutdown` as harmless.
#[cfg(feature = "quic")]
fn spawn_quic_reader(
    task_group: &TaskGroup,
    proto_id: ProtocolID,
    mut reader: FramedReader<PeerNetMsgCodec>,
    recv_queue: Arc<AsyncQueue<ProtocolEvent>>,
) {
    task_group.spawn(
        async move {
            let outcome: Result<()> = async {
                loop {
                    let msg = match reader.recv_msg().await {
                        Ok(m) => m,
                        // Remote closed / reset the stream or the connection.
                        Err(_) => return Ok(()),
                    };
                    match msg.header.command {
                        PeerNetCmd::Protocol => {
                            let proto_msg: ProtocolMsg = decode(&msg.payload)?.0;
                            recv_queue
                                .push(ProtocolEvent::Message(proto_msg.payload))
                                .await;
                        }
                        PeerNetCmd::Shutdown => return Ok(()),
                        cmd => {
                            error!("Unexpected msg on QUIC stream {proto_id}: {cmd:?}");
                        }
                    }
                }
            }
            .await;

            // The stream will deliver nothing more: wake any blocked `recv`.
            recv_queue.push(ProtocolEvent::Shutdown).await;
            outcome
        },
        |res: TaskResult<Result<()>>| async move {
            debug!("QUIC stream reader task ended: {res}");
        },
    );
}
325
/// Open (outbound) or accept (inbound) one QUIC stream per protocol in
/// `negotiated`. Returns a fresh send queue per protocol, together with the
/// framed reader/writer halves of each stream.
///
/// Wire format: the opening side writes an encoded `StreamInit` naming the
/// protocol as the first bytes of each stream. `PeerNetMsgCodec` frames
/// follow immediately after it.
///
/// Inbound, this waits until one stream per negotiated protocol has been
/// accepted. Streams that name a protocol outside `negotiated`, or one that
/// already has a stream, are logged and dropped.
///
/// # Errors
/// * Propagates I/O errors from opening, accepting, reading or writing.
/// * Propagates encode errors, and decode errors for headers longer than
///   `MAX_STREAM_INIT_LEN` bytes.
/// * Returns `Error::InvalidMsg` when an inbound stream ends part-way
///   through its `StreamInit` header.
#[cfg(feature = "quic")]
async fn setup_quic_streams(
    quic_conn: &QuicConn,
    direction: &ConnDirection,
    negotiated: &HashSet<ProtocolID>,
) -> Result<(
    HashMap<ProtocolID, Arc<AsyncQueue<PeerNetMsg>>>,
    Vec<MuxStream>,
)> {
    // Largest encoded `StreamInit` accepted from the remote.
    const MAX_STREAM_INIT_LEN: usize = 256;

    let mut queues = HashMap::new();
    let mut streams = Vec::new();

    match direction {
        ConnDirection::Outbound => {
            for proto_id in negotiated.iter() {
                let mut stream = quic_conn.open_stream().await?;

                let init = StreamInit {
                    protocol_id: proto_id.clone(),
                };
                let encoded = encode(&init)?;
                stream.write_all(&encoded).await?;
                stream.flush().await?;

                let conn = framed(stream, PeerNetMsgCodec::new());
                let (reader, writer) = conn.split();

                let q = AsyncQueue::new(SEND_QUEUE_SIZE);
                queues.insert(proto_id.clone(), q);
                streams.push(MuxStream {
                    proto_id: proto_id.clone(),
                    reader,
                    writer,
                });
            }
        }
        ConnDirection::Inbound => {
            let expected = negotiated.len();
            let mut received = 0;

            while received < expected {
                let mut stream = quic_conn.accept_stream().await?;

                // Read the header one byte at a time until it decodes. A
                // single bulk `read` could return only part of the header,
                // or also consume the first framed messages the remote sent
                // right behind it. Those bytes would then be lost to the
                // framed reader and the stream would be corrupted.
                let mut header = Vec::with_capacity(MAX_STREAM_INIT_LEN);
                let mut byte = [0u8; 1];
                let init = loop {
                    if stream.read(&mut byte).await? == 0 {
                        if header.is_empty() {
                            break None;
                        }
                        return Err(Error::InvalidMsg(
                            "QUIC stream closed inside StreamInit header".into(),
                        ));
                    }
                    header.push(byte[0]);
                    let attempt: Result<(StreamInit, _)> = decode(&header[..]);
                    match attempt {
                        Ok((init, _)) => break Some(init),
                        Err(e) if header.len() >= MAX_STREAM_INIT_LEN => return Err(e),
                        // Not enough bytes yet; keep reading.
                        Err(_) => {}
                    }
                };
                let init = match init {
                    Some(init) => init,
                    // The remote closed the stream without sending anything.
                    None => continue,
                };

                if !negotiated.contains(&init.protocol_id) {
                    error!("Unsupported protocol: {}", init.protocol_id);
                    continue;
                }
                // A second stream for the same protocol would overwrite its
                // queue while still counting toward `expected`, leaving some
                // other negotiated protocol without a stream.
                if queues.contains_key(&init.protocol_id) {
                    error!("Duplicate stream for protocol: {}", init.protocol_id);
                    continue;
                }

                let conn = framed(stream, PeerNetMsgCodec::new());
                let (reader, writer) = conn.split();

                let q = AsyncQueue::new(SEND_QUEUE_SIZE);
                queues.insert(init.protocol_id.clone(), q);
                streams.push(MuxStream {
                    proto_id: init.protocol_id,
                    reader,
                    writer,
                });

                received += 1;
            }
        }
    }

    Ok((queues, streams))
}
402
403/// One bounded recv queue per protocol id.
404fn build_recv_queues(
405    proto_ids: impl IntoIterator<Item = ProtocolID>,
406) -> HashMap<ProtocolID, Arc<AsyncQueue<ProtocolEvent>>> {
407    proto_ids
408        .into_iter()
409        .map(|id| (id, AsyncQueue::new(RECV_QUEUE_SIZE)))
410        .collect()
411}
412
413/// Push `ProtocolEvent::Shutdown` into every recv queue, waking
414/// blocked `recv` callers.
415async fn broadcast_shutdown(queues: &HashMap<ProtocolID, Arc<AsyncQueue<ProtocolEvent>>>) {
416    for q in queues.values() {
417        q.push(ProtocolEvent::Shutdown).await;
418    }
419}