karyon_core/
pubsub.rs

1use std::{collections::HashMap, sync::Arc};
2
3use futures_util::stream::{FuturesUnordered, StreamExt};
4use log::error;
5
6use crate::{async_runtime::lock::Mutex, util::random_32, Result};
7
/// Default capacity of every subscription channel, used by [`Publisher::new`].
const CHANNEL_BUFFER_SIZE: usize = 1_000;

/// Identifier a [`Publisher`] assigns to each of its subscriptions.
pub type SubscriptionID = u32;
11
12/// A simple publish-subscribe system.
13// # Example
14///
15/// ```
16/// use karyon_core::pubsub::{Publisher};
17///
18///  async {
19///     let publisher = Publisher::new();
20///     
21///     let sub = publisher.subscribe().await;
22///     
23///     publisher.notify(&String::from("MESSAGE")).await;
24///
25///     let msg = sub.recv().await;
26///
27///     // ....
28///  };
29///  
30/// ```
31pub struct Publisher<T> {
32    subs: Mutex<HashMap<SubscriptionID, async_channel::Sender<T>>>,
33    subscription_buffer_size: usize,
34}
35
36impl<T: Clone> Publisher<T> {
37    /// Creates a new [`Publisher`]
38    pub fn new() -> Arc<Publisher<T>> {
39        Arc::new(Self {
40            subs: Mutex::new(HashMap::new()),
41            subscription_buffer_size: CHANNEL_BUFFER_SIZE,
42        })
43    }
44
45    /// Creates a new [`Publisher`] with the provided buffer size for the
46    /// [`Subscription`] channel.
47    ///
48    /// This is important to control the memory used by the [`Subscription`] channel.
49    /// If the subscriber can't keep up with the new messages coming, then the
50    /// channel buffer will fill with new messages, and if the buffer is full,
51    /// the emit function will block until the subscriber starts to process
52    /// the buffered messages.
53    ///
54    /// If `size` is zero, this function will panic.
55    pub fn with_buffer_size(size: usize) -> Arc<Publisher<T>> {
56        Arc::new(Self {
57            subs: Mutex::new(HashMap::new()),
58            subscription_buffer_size: size,
59        })
60    }
61
62    /// Subscribes and return a [`Subscription`]
63    pub async fn subscribe(self: &Arc<Self>) -> Subscription<T> {
64        let mut subs = self.subs.lock().await;
65
66        let chan = async_channel::bounded(self.subscription_buffer_size);
67
68        let mut sub_id = random_32();
69
70        // Generate a new one if sub_id already exists
71        while subs.contains_key(&sub_id) {
72            sub_id = random_32();
73        }
74
75        let sub = Subscription::new(sub_id, self.clone(), chan.1);
76        subs.insert(sub_id, chan.0);
77
78        sub
79    }
80
81    /// Unsubscribes by providing subscription id
82    pub async fn unsubscribe(self: &Arc<Self>, id: &SubscriptionID) {
83        self.subs.lock().await.remove(id);
84    }
85
86    /// Notifies all subscribers
87    pub async fn notify(self: &Arc<Self>, value: &T) {
88        let mut subs = self.subs.lock().await;
89
90        let mut results = FuturesUnordered::new();
91        let mut closed_subs = vec![];
92
93        for (sub_id, sub) in subs.iter() {
94            let result = async { (*sub_id, sub.send(value.clone()).await) };
95            results.push(result);
96        }
97
98        while let Some((id, fut_err)) = results.next().await {
99            if let Err(err) = fut_err {
100                error!("failed to notify {}: {}", id, err);
101                closed_subs.push(id);
102            }
103        }
104        drop(results);
105
106        for sub_id in closed_subs.iter() {
107            subs.remove(sub_id);
108        }
109    }
110}
111
112// Subscription
113pub struct Subscription<T> {
114    id: SubscriptionID,
115    recv_chan: async_channel::Receiver<T>,
116    publisher: Arc<Publisher<T>>,
117}
118
119impl<T: Clone> Subscription<T> {
120    /// Creates a new [`Subscription`]
121    pub fn new(
122        id: SubscriptionID,
123        publisher: Arc<Publisher<T>>,
124        recv_chan: async_channel::Receiver<T>,
125    ) -> Subscription<T> {
126        Self {
127            id,
128            recv_chan,
129            publisher,
130        }
131    }
132
133    /// Receive a message from the [`Publisher`]
134    pub async fn recv(&self) -> Result<T> {
135        let msg = self.recv_chan.recv().await?;
136        Ok(msg)
137    }
138
139    /// Unsubscribe from the [`Publisher`]
140    pub async fn unsubscribe(&self) {
141        self.publisher.unsubscribe(&self.id).await;
142    }
143}