karyon_jsonrpc/client/
message_dispatcher.rs

1use std::collections::HashMap;
2
3use async_channel::{Receiver, Sender};
4
5use karyon_core::async_runtime::lock::Mutex;
6
7use crate::{
8    error::{Error, Result},
9    message,
10};
11
12use super::RequestID;
13
14/// Manages client requests
15pub(super) struct MessageDispatcher {
16    chans: Mutex<HashMap<RequestID, Sender<message::Response>>>,
17}
18
19impl MessageDispatcher {
20    /// Creates a new MessageDispatcher
21    pub(super) fn new() -> Self {
22        Self {
23            chans: Mutex::new(HashMap::new()),
24        }
25    }
26
27    /// Registers a new request with a given ID and returns a Receiver channel
28    /// to wait for the response.
29    pub(super) async fn register(&self, id: RequestID) -> Receiver<message::Response> {
30        let (tx, rx) = async_channel::bounded(1);
31        self.chans.lock().await.insert(id, tx);
32        rx
33    }
34
35    /// Unregisters the request with the provided ID
36    pub(super) async fn unregister(&self, id: &RequestID) {
37        self.chans.lock().await.remove(id);
38    }
39
40    /// Clear the registered channels.
41    pub(super) async fn clear(&self) {
42        let mut chans = self.chans.lock().await;
43        for (_, tx) in chans.iter() {
44            tx.close();
45        }
46        chans.clear();
47    }
48
49    /// Dispatches a response to the channel associated with the response's ID.
50    ///
51    /// If a channel is registered for the response's ID, the response is sent
52    /// through that channel. If no channel is found for the ID, returns an error.
53    pub(super) async fn dispatch(&self, res: message::Response) -> Result<()> {
54        let res_id = match res.id {
55            Some(ref rid) => rid.clone(),
56            None => {
57                return Err(Error::InvalidMsg("Response id is none".to_string()));
58            }
59        };
60        let id: RequestID = serde_json::from_value(res_id)?;
61        let val = self.chans.lock().await.remove(&id);
62        match val {
63            Some(tx) => tx.send(res).await.map_err(Error::from),
64            None => Err(Error::InvalidMsg("Receive unknown message".to_string())),
65        }
66    }
67}