Skip to main content

karyon_jsonrpc/server/
acceptor.rs

1//! Acceptors used by the stream-based and WebSocket backends.
2//! Accepts a connection, wraps it with a codec, and hands the split
3//! halves to the server.
4
5use std::sync::Arc;
6
7use async_trait::async_trait;
8
9use karyon_net::{framed, ByteStream, Endpoint};
10
11use crate::{
12    codec::JsonRpcCodec,
13    error::{Error, Result},
14    server::Server,
15};
16
17#[cfg(feature = "ws")]
18use karyon_net::{layers::ws::WsLayer, ServerLayer};
19
20#[cfg(feature = "ws")]
21use crate::codec::JsonRpcWsCodec;
22
/// Produces framed connections and hands them off to the server.
///
/// Implemented in this file by `StreamAcceptor` and, when the `ws`
/// feature is enabled, by `WsAcceptor`.
#[async_trait]
pub(super) trait AsyncAcceptor: Send + Sync {
    /// Accepts one incoming connection and passes it to `server`.
    ///
    /// # Errors
    /// Returns an error when the connection cannot be accepted or set up.
    async fn accept_and_handle(&self, server: &Arc<Server>) -> Result<()>;
    /// Returns the endpoint this acceptor reports as its local address.
    fn local_endpoint(&self) -> Result<Endpoint>;
}
29
/// A listener that produces byte streams.
///
/// Used as a boxed trait object by the acceptors in this file, so they do
/// not depend on the concrete transport listener type.
#[async_trait]
pub(super) trait StreamListener: Send + Sync {
    /// Waits for the next incoming connection and returns it as a boxed
    /// byte stream.
    async fn accept(&self) -> karyon_net::Result<Box<dyn ByteStream>>;
    /// Returns the endpoint the listener reports as its local address.
    fn local_endpoint(&self) -> karyon_net::Result<Endpoint>;
}
36
/// Exposes the TCP listener through [`StreamListener`] by forwarding each
/// call to the listener's own method of the same name.
#[cfg(feature = "tcp")]
#[async_trait]
impl StreamListener for karyon_net::tcp::TcpListener {
    async fn accept(&self) -> karyon_net::Result<Box<dyn ByteStream>> {
        // Type-path call: resolves like `self.accept()`, i.e. to the
        // listener's own `accept` (presumably inherent) ahead of this
        // trait method.
        karyon_net::tcp::TcpListener::accept(self).await
    }

    fn local_endpoint(&self) -> karyon_net::Result<Endpoint> {
        karyon_net::tcp::TcpListener::local_endpoint(self)
    }
}
47
/// Exposes the TLS listener through [`StreamListener`] by forwarding each
/// call to the listener's own method of the same name.
#[cfg(feature = "tls")]
#[async_trait]
impl StreamListener for karyon_net::tls::TlsListener {
    async fn accept(&self) -> karyon_net::Result<Box<dyn ByteStream>> {
        // Type-path call: resolves like `self.accept()`, i.e. to the
        // listener's own `accept` (presumably inherent) ahead of this
        // trait method.
        karyon_net::tls::TlsListener::accept(self).await
    }

    fn local_endpoint(&self) -> karyon_net::Result<Endpoint> {
        karyon_net::tls::TlsListener::local_endpoint(self)
    }
}
58
/// Exposes the Unix-domain listener through [`StreamListener`] by
/// forwarding each call to the listener's own method of the same name.
#[cfg(all(feature = "unix", target_family = "unix"))]
#[async_trait]
impl StreamListener for karyon_net::unix::UnixListener {
    async fn accept(&self) -> karyon_net::Result<Box<dyn ByteStream>> {
        // Type-path call: resolves like `self.accept()`, i.e. to the
        // listener's own `accept` (presumably inherent) ahead of this
        // trait method.
        karyon_net::unix::UnixListener::accept(self).await
    }

    fn local_endpoint(&self) -> karyon_net::Result<Endpoint> {
        karyon_net::unix::UnixListener::local_endpoint(self)
    }
}
69
/// Byte-stream acceptor: accepts a stream, wraps with a framed codec,
/// and hands the split halves to the server.
pub(super) struct StreamAcceptor<C> {
    /// Source of incoming byte streams.
    pub(super) listener: Box<dyn StreamListener>,
    /// Codec template; a clone of it is used for every accepted stream.
    pub(super) codec: C,
}
76
77#[async_trait]
78impl<C> AsyncAcceptor for StreamAcceptor<C>
79where
80    C: JsonRpcCodec,
81{
82    async fn accept_and_handle(&self, server: &Arc<Server>) -> Result<()> {
83        let stream = self.listener.accept().await?;
84        let conn = framed(stream, self.codec.clone());
85        let peer = conn.peer_endpoint();
86        let (reader, writer) = conn.split();
87        server.handle_message_conn(reader, writer, peer);
88        Ok(())
89    }
90    fn local_endpoint(&self) -> Result<Endpoint> {
91        self.listener.local_endpoint().map_err(Error::from)
92    }
93}
94
/// WebSocket acceptor: accepts a byte stream, runs the WS handshake,
/// and hands the split halves to the server.
#[cfg(feature = "ws")]
pub(super) struct WsAcceptor<W> {
    /// Source of the raw byte streams the handshake runs over.
    pub(super) listener: Box<dyn StreamListener>,
    /// Shared WebSocket layer used for the server-side handshake on each
    /// accepted stream.
    pub(super) layer: Arc<WsLayer<W>>,
    /// `true` for `wss://`, used when reporting `local_endpoint`.
    pub(super) tls: bool,
}
104
#[cfg(feature = "ws")]
#[async_trait]
impl<W> AsyncAcceptor for WsAcceptor<W>
where
    W: JsonRpcWsCodec,
{
    /// Accepts one byte stream, performs the server-side WebSocket
    /// handshake on it, and passes the split connection to `server`.
    ///
    /// # Errors
    /// Propagates a failed `accept` on the listener or a failed handshake.
    async fn accept_and_handle(&self, server: &Arc<Server>) -> Result<()> {
        let stream = self.listener.accept().await?;
        // NOTE(review): the handshake is awaited inline, so this call does
        // not return until the peer completes (or fails) it. If the caller
        // accepts in a loop, a slow peer would delay the next accept —
        // confirm whether a timeout is applied elsewhere.
        let conn = ServerLayer::handshake(self.layer.as_ref(), stream).await?;
        let peer = conn.peer_endpoint();
        let (reader, writer) = conn.split();
        server.handle_message_conn(reader, writer, peer);
        Ok(())
    }

    /// Returns the listener's socket address rewritten as a `ws://` or
    /// `wss://` endpoint, depending on `self.tls`.
    ///
    /// # Errors
    /// Fails if the listener cannot report its endpoint, if that endpoint
    /// cannot be converted to a socket address, or if the rewritten string
    /// does not parse as an endpoint.
    fn local_endpoint(&self) -> Result<Endpoint> {
        // The listener reports `tcp://...`; rewrite to the WS scheme so
        // a client building from this endpoint runs the WS handshake.
        let inner = self.listener.local_endpoint()?;
        // `inner` is not needed afterwards, so it is moved, not cloned.
        let addr = std::net::SocketAddr::try_from(inner)?;
        let scheme = if self.tls { "wss" } else { "ws" };
        format!("{scheme}://{addr}/")
            .parse::<Endpoint>()
            .map_err(Error::from)
    }
}