karyon_core/async_util/
condvar.rs

1use std::{
2    collections::HashMap,
3    future::Future,
4    pin::Pin,
5    task::{Context, Poll, Waker},
6};
7
8use parking_lot::Mutex;
9
10use crate::{async_runtime::lock::MutexGuard, util::random_16};
11
12/// CondVar is an async version of <https://doc.rust-lang.org/std/sync/struct.Condvar.html>
13///
14/// # Example
15///
16///```
17/// use std::sync::Arc;
18///
19/// use karyon_core::async_util::CondVar;
20/// use karyon_core::async_runtime::{spawn, lock::Mutex};
21///
22///  async {
23///     
24///     let val = Arc::new(Mutex::new(false));
25///     let condvar = Arc::new(CondVar::new());
26///
27///     spawn({
28///         let val = val.clone();
29///         let condvar = condvar.clone();
30///         async move {
31///             let mut val = val.lock().await;
32///
33///             // While the boolean flag is false, wait for a signal.
34///             while !*val {
35///                 val = condvar.wait(val).await;
36///             }
37///
38///             // ...
39///         }
40///     });
41///
42///     spawn({
43///         let condvar = condvar.clone();
44///         async move {
45///             let mut val = val.lock().await;
46///
47///             // While the boolean flag is false, wait for a signal.
48///             while !*val {
49///                 val = condvar.wait(val).await;
50///             }
51///
52///             // ...
53///         }
54///     });
55///     
56///     // Wake up all waiting tasks on this condvar
57///     condvar.broadcast();
58///  };
59///
60/// ```
61pub struct CondVar {
62    inner: Mutex<Wakers>,
63}
64
65impl CondVar {
66    /// Creates a new CondVar
67    pub fn new() -> Self {
68        Self {
69            inner: Mutex::new(Wakers::new()),
70        }
71    }
72
73    /// Blocks the current task until this condition variable receives a notification.
74    pub async fn wait<'a, T>(&self, g: MutexGuard<'a, T>) -> MutexGuard<'a, T> {
75        #[cfg(feature = "smol")]
76        let m = MutexGuard::source(&g);
77        #[cfg(feature = "tokio")]
78        let m = MutexGuard::mutex(&g);
79
80        CondVarAwait::new(self, g).await;
81
82        m.lock().await
83    }
84
85    /// Wakes up one blocked task waiting on this condvar.
86    pub fn signal(&self) {
87        self.inner.lock().wake(true);
88    }
89
90    /// Wakes up all blocked tasks waiting on this condvar.
91    pub fn broadcast(&self) {
92        self.inner.lock().wake(false);
93    }
94}
95
96impl Default for CondVar {
97    fn default() -> Self {
98        Self::new()
99    }
100}
101
102struct CondVarAwait<'a, T> {
103    id: Option<u16>,
104    condvar: &'a CondVar,
105    guard: Option<MutexGuard<'a, T>>,
106}
107
108impl<'a, T> CondVarAwait<'a, T> {
109    fn new(condvar: &'a CondVar, guard: MutexGuard<'a, T>) -> Self {
110        Self {
111            condvar,
112            guard: Some(guard),
113            id: None,
114        }
115    }
116}
117
118impl<T> Future for CondVarAwait<'_, T> {
119    type Output = ();
120
121    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
122        let mut inner = self.condvar.inner.lock();
123
124        match self.guard.take() {
125            Some(_) => {
126                // the first pooll will release the Mutexguard
127                self.id = Some(inner.put(Some(cx.waker().clone())));
128                Poll::Pending
129            }
130            None => {
131                // Return Ready if it has already been polled and removed
132                // from the waker list.
133                if self.id.is_none() {
134                    return Poll::Ready(());
135                }
136
137                let i = self.id.as_ref().unwrap();
138                match inner.wakers.get_mut(i).unwrap() {
139                    Some(wk) => {
140                        // This will prevent cloning again
141                        if !wk.will_wake(cx.waker()) {
142                            wk.clone_from(cx.waker());
143                        }
144                        Poll::Pending
145                    }
146                    None => {
147                        inner.delete(i);
148                        self.id = None;
149                        Poll::Ready(())
150                    }
151                }
152            }
153        }
154    }
155}
156
157impl<T> Drop for CondVarAwait<'_, T> {
158    fn drop(&mut self) {
159        if let Some(id) = self.id {
160            let mut inner = self.condvar.inner.lock();
161            if let Some(wk) = inner.wakers.get_mut(&id).unwrap().take() {
162                wk.wake()
163            }
164        }
165    }
166}
167
/// Registry of tasks waiting on a `CondVar`.
///
/// Each waiting task owns one entry keyed by a `u16` id. The value is
/// `Some(waker)` while the task is still waiting, and becomes `None` once the
/// waker has been taken to notify it.
struct Wakers {
    wakers: HashMap<u16, Option<Waker>>,
}
172
173impl Wakers {
174    fn new() -> Self {
175        Self {
176            wakers: HashMap::new(),
177        }
178    }
179
180    fn put(&mut self, waker: Option<Waker>) -> u16 {
181        let mut id: u16;
182
183        id = random_16();
184        while self.wakers.contains_key(&id) {
185            id = random_16();
186        }
187
188        self.wakers.insert(id, waker);
189        id
190    }
191
192    fn delete(&mut self, id: &u16) -> Option<Option<Waker>> {
193        self.wakers.remove(id)
194    }
195
196    fn wake(&mut self, signal: bool) {
197        for (_, wk) in self.wakers.iter_mut() {
198            match wk.take() {
199                Some(w) => {
200                    w.wake();
201                    if signal {
202                        break;
203                    }
204                }
205                None => continue,
206            }
207        }
208    }
209}
210
#[cfg(test)]
mod tests {
    use std::{
        collections::VecDeque,
        sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        },
    };

    use crate::async_runtime::{block_on, lock::Mutex, spawn};

    use super::*;

    // The tests below demonstrate a solution to a problem in the Wikipedia
    // explanation of condition variables:
    // https://en.wikipedia.org/wiki/Monitor_(synchronization)#Solving_the_bounded_producer/consumer_problem.

    /// Bounded FIFO queue shared by the producer and consumer tasks.
    struct Queue {
        items: VecDeque<String>,
        max_len: usize,
    }

    impl Queue {
        fn new(max_len: usize) -> Self {
            Self {
                items: VecDeque::new(),
                max_len,
            }
        }

        fn is_full(&self) -> bool {
            self.items.len() == self.max_len
        }

        fn is_empty(&self) -> bool {
            self.items.is_empty()
        }
    }

    /// One producer and one consumer, with one condvar per condition and `signal`.
    #[test]
    fn test_condvar_signal() {
        block_on(async {
            let number_of_tasks: usize = 30;
            // Items are produced for `1..number_of_tasks`.
            let expected_items = number_of_tasks - 1;

            let queue = Arc::new(Mutex::new(Queue::new(5)));
            let condvar_full = Arc::new(CondVar::new());
            let condvar_empty = Arc::new(CondVar::new());

            let _producer1 = spawn({
                let queue = queue.clone();
                let condvar_full = condvar_full.clone();
                let condvar_empty = condvar_empty.clone();
                async move {
                    for i in 1..number_of_tasks {
                        // Lock queue mutex
                        let mut queue = queue.lock().await;

                        // Check if the queue is non-full
                        while queue.is_full() {
                            // Release queue mutex and sleep
                            queue = condvar_full.wait(queue).await;
                        }

                        queue.items.push_back(format!("task {i}"));

                        // Wake up the consumer
                        condvar_empty.signal();
                    }
                }
            });

            let task_consumed = Arc::new(AtomicUsize::new(0));

            let consumer = spawn({
                let queue = queue.clone();
                let task_consumed = task_consumed.clone();
                async move {
                    for _ in 0..expected_items {
                        // Lock queue mutex
                        let mut queue = queue.lock().await;

                        // Check if the queue is non-empty
                        while queue.is_empty() {
                            // Release queue mutex and sleep
                            queue = condvar_empty.wait(queue).await;
                        }

                        let _ = queue.items.pop_front().unwrap();

                        task_consumed.fetch_add(1, Ordering::Relaxed);

                        // Do something

                        // Wake up the producer
                        condvar_full.signal();
                    }
                }
            });

            let _ = consumer.await;
            assert!(queue.lock().await.is_empty());
            assert_eq!(task_consumed.load(Ordering::Relaxed), expected_items);
        });
    }

    /// Two producers and one consumer sharing a single condvar with `broadcast`.
    #[test]
    fn test_condvar_broadcast() {
        block_on(async {
            let tasks: usize = 30;
            // Each producer pushes items for `1..tasks`.
            let items_per_producer = tasks - 1;
            let expected_items = 2 * items_per_producer;

            let queue = Arc::new(Mutex::new(Queue::new(5)));
            let condvar = Arc::new(CondVar::new());

            let _producer1 = spawn({
                let queue = queue.clone();
                let condvar = condvar.clone();
                async move {
                    for i in 1..tasks {
                        // Lock queue mutex
                        let mut queue = queue.lock().await;

                        // Check if the queue is non-full
                        while queue.is_full() {
                            // Release queue mutex and sleep
                            queue = condvar.wait(queue).await;
                        }

                        queue.items.push_back(format!("producer1: task {i}"));

                        // Wake up all producer and consumer tasks
                        condvar.broadcast();
                    }
                }
            });

            let _producer2 = spawn({
                let queue = queue.clone();
                let condvar = condvar.clone();
                async move {
                    for i in 1..tasks {
                        // Lock queue mutex
                        let mut queue = queue.lock().await;

                        // Check if the queue is non-full
                        while queue.is_full() {
                            // Release queue mutex and sleep
                            queue = condvar.wait(queue).await;
                        }

                        queue.items.push_back(format!("producer2: task {i}"));

                        // Wake up all producer and consumer tasks
                        condvar.broadcast();
                    }
                }
            });

            let task_consumed = Arc::new(AtomicUsize::new(0));

            let consumer = spawn({
                let queue = queue.clone();
                let task_consumed = task_consumed.clone();
                async move {
                    for _ in 0..expected_items {
                        // Lock queue mutex
                        let mut queue = queue.lock().await;

                        // Check if the queue is non-empty
                        while queue.is_empty() {
                            // Release queue mutex and sleep
                            queue = condvar.wait(queue).await;
                        }

                        let _ = queue.items.pop_front().unwrap();

                        task_consumed.fetch_add(1, Ordering::Relaxed);

                        // Do something

                        // Wake up all producer and consumer tasks
                        condvar.broadcast();
                    }
                }
            });

            let _ = consumer.await;
            assert!(queue.lock().await.is_empty());
            assert_eq!(task_consumed.load(Ordering::Relaxed), expected_items);
        });
    }
}