karyon_core/async_util/
condvar.rs

1use std::{
2    collections::{hash_map::Entry, HashMap},
3    future::Future,
4    pin::Pin,
5    task::{Context, Poll, Waker},
6};
7
8use parking_lot::Mutex;
9
10use crate::{async_runtime::lock::MutexGuard, util::random_16};
11
12/// CondVar is an async version of <https://doc.rust-lang.org/std/sync/struct.Condvar.html>
13///
14/// # Example
15///
16///```
17/// use std::sync::Arc;
18///
19/// use karyon_core::async_util::CondVar;
20/// use karyon_core::async_runtime::{spawn, lock::Mutex};
21///
22///  async {
23///     
24///     let val = Arc::new(Mutex::new(false));
25///     let condvar = Arc::new(CondVar::new());
26///
27///     spawn({
28///         let val = val.clone();
29///         let condvar = condvar.clone();
30///         async move {
31///             let mut val = val.lock().await;
32///
33///             // While the boolean flag is false, wait for a signal.
34///             while !*val {
35///                 val = condvar.wait(val).await;
36///             }
37///
38///             // ...
39///         }
40///     });
41///
42///     spawn({
43///         let condvar = condvar.clone();
44///         async move {
45///             let mut val = val.lock().await;
46///
47///             // While the boolean flag is false, wait for a signal.
48///             while !*val {
49///                 val = condvar.wait(val).await;
50///             }
51///
52///             // ...
53///         }
54///     });
55///     
56///     // Wake up all waiting tasks on this condvar
57///     condvar.broadcast();
58///  };
59///
60/// ```
61pub struct CondVar {
62    inner: Mutex<Wakers>,
63}
64
65impl CondVar {
66    /// Creates a new CondVar
67    pub fn new() -> Self {
68        Self {
69            inner: Mutex::new(Wakers::new()),
70        }
71    }
72
73    /// Blocks the current task until this condition variable receives a notification.
74    pub async fn wait<'a, T>(&self, g: MutexGuard<'a, T>) -> MutexGuard<'a, T> {
75        #[cfg(feature = "smol")]
76        let m = MutexGuard::source(&g);
77        #[cfg(feature = "tokio")]
78        let m = MutexGuard::mutex(&g);
79
80        CondVarAwait::new(self, g).await;
81
82        m.lock().await
83    }
84
85    /// Wakes up one blocked task waiting on this condvar.
86    pub fn signal(&self) {
87        self.inner.lock().wake(true);
88    }
89
90    /// Wakes up all blocked tasks waiting on this condvar.
91    pub fn broadcast(&self) {
92        self.inner.lock().wake(false);
93    }
94}
95
96impl Default for CondVar {
97    fn default() -> Self {
98        Self::new()
99    }
100}
101
102struct CondVarAwait<'a, T> {
103    id: Option<u16>,
104    condvar: &'a CondVar,
105    guard: Option<MutexGuard<'a, T>>,
106}
107
108impl<'a, T> CondVarAwait<'a, T> {
109    fn new(condvar: &'a CondVar, guard: MutexGuard<'a, T>) -> Self {
110        Self {
111            condvar,
112            guard: Some(guard),
113            id: None,
114        }
115    }
116}
117
118impl<T> Future for CondVarAwait<'_, T> {
119    type Output = ();
120
121    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
122        let mut inner = self.condvar.inner.lock();
123
124        match self.guard.take() {
125            Some(_) => {
126                // the first pooll will release the Mutexguard
127                self.id = Some(inner.put(Some(cx.waker().clone())));
128                Poll::Pending
129            }
130            None => {
131                // Return Ready if it has already been polled and removed
132                // from the waker list.
133                if self.id.is_none() {
134                    return Poll::Ready(());
135                }
136
137                let i = self.id.as_ref().unwrap();
138                let waker_op = match inner.wakers.get_mut(i) {
139                    Some(wk) => wk,
140                    None => {
141                        self.id = None;
142                        return Poll::Ready(());
143                    }
144                };
145
146                match waker_op {
147                    Some(wk) => {
148                        // This will prevent cloning again
149                        if !wk.will_wake(cx.waker()) {
150                            wk.clone_from(cx.waker());
151                        }
152                        Poll::Pending
153                    }
154                    None => {
155                        inner.delete(i);
156                        self.id = None;
157                        Poll::Ready(())
158                    }
159                }
160            }
161        }
162    }
163}
164
165impl<T> Drop for CondVarAwait<'_, T> {
166    fn drop(&mut self) {
167        if let Some(id) = self.id {
168            let mut inner = self.condvar.inner.lock();
169            if let Some(wk) = inner.wakers.remove(&id).flatten() {
170                wk.wake()
171            }
172        }
173    }
174}
175
176/// Wakers is a helper struct to store the task wakers
177struct Wakers {
178    wakers: HashMap<u16, Option<Waker>>,
179}
180
181impl Wakers {
182    fn new() -> Self {
183        Self {
184            wakers: HashMap::new(),
185        }
186    }
187
188    fn put(&mut self, waker: Option<Waker>) -> u16 {
189        const MAX_RETRIES: u8 = 100;
190        let mut id: u16;
191
192        for _ in 0..MAX_RETRIES {
193            id = random_16();
194            if let Entry::Vacant(e) = self.wakers.entry(id) {
195                e.insert(waker);
196                return id;
197            }
198        }
199
200        panic!("Wakers: All IDs exhausted");
201    }
202
203    fn delete(&mut self, id: &u16) -> Option<Option<Waker>> {
204        self.wakers.remove(id)
205    }
206
207    fn wake(&mut self, signal: bool) {
208        for (_, wk) in self.wakers.iter_mut() {
209            match wk.take() {
210                Some(w) => {
211                    w.wake();
212                    if signal {
213                        break;
214                    }
215                }
216                None => continue,
217            }
218        }
219    }
220}
221
#[cfg(test)]
mod tests {
    use std::{
        collections::VecDeque,
        sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        },
    };

    use crate::async_runtime::{block_on, lock::Mutex, spawn};

    use super::*;

    // The tests below demonstrate a solution to a problem in the Wikipedia
    // explanation of condition variables:
    // https://en.wikipedia.org/wiki/Monitor_(synchronization)#Solving_the_bounded_producer/consumer_problem.

    /// A bounded FIFO queue of string items shared by producer and consumer
    /// tasks.
    struct Queue {
        items: VecDeque<String>,
        max_len: usize,
    }

    impl Queue {
        fn new(max_len: usize) -> Self {
            Self {
                items: VecDeque::new(),
                max_len,
            }
        }

        fn is_full(&self) -> bool {
            self.items.len() == self.max_len
        }

        fn is_empty(&self) -> bool {
            self.items.is_empty()
        }
    }

    /// One producer and one consumer coordinating through two condvars using
    /// `signal`.
    #[test]
    fn test_condvar_signal() {
        block_on(async {
            let number_of_tasks = 30;
            // The producer pushes one item per value in `1..number_of_tasks`.
            let expected_items = number_of_tasks - 1;

            let queue = Arc::new(Mutex::new(Queue::new(5)));
            let condvar_full = Arc::new(CondVar::new());
            let condvar_empty = Arc::new(CondVar::new());

            // Bound to a named variable (not `_`) so the task handle lives
            // until the end of the test. NOTE(review): with some runtimes,
            // dropping the handle cancels the task.
            let _producer1 = spawn({
                let queue = queue.clone();
                let condvar_full = condvar_full.clone();
                let condvar_empty = condvar_empty.clone();
                async move {
                    for i in 1..number_of_tasks {
                        // Lock queue mutex
                        let mut queue = queue.lock().await;

                        // Check if the queue is non-full
                        while queue.is_full() {
                            // Release queue mutex and sleep
                            queue = condvar_full.wait(queue).await;
                        }

                        queue.items.push_back(format!("task {i}"));

                        // Wake up the consumer
                        condvar_empty.signal();
                    }
                }
            });

            let task_consumed = Arc::new(AtomicUsize::new(0));

            let consumer = spawn({
                let queue = queue.clone();
                let task_consumed = task_consumed.clone();
                async move {
                    for _ in 0..expected_items {
                        // Lock queue mutex
                        let mut queue = queue.lock().await;

                        // Check if the queue is non-empty
                        while queue.is_empty() {
                            // Release queue mutex and sleep
                            queue = condvar_empty.wait(queue).await;
                        }

                        let _ = queue.items.pop_front().unwrap();

                        task_consumed.fetch_add(1, Ordering::Relaxed);

                        // Do something

                        // Wake up the producer
                        condvar_full.signal();
                    }
                }
            });

            let _ = consumer.await;
            assert!(queue.lock().await.is_empty());
            assert_eq!(task_consumed.load(Ordering::Relaxed), expected_items);
        });
    }

    /// Two producers and one consumer sharing a single condvar using
    /// `broadcast`.
    #[test]
    fn test_condvar_broadcast() {
        block_on(async {
            let tasks = 30;
            // Each of the two producers pushes one item per value in `1..tasks`.
            let expected_items = (tasks - 1) * 2;

            let queue = Arc::new(Mutex::new(Queue::new(5)));
            let condvar = Arc::new(CondVar::new());

            let _producer1 = spawn({
                let queue = queue.clone();
                let condvar = condvar.clone();
                async move {
                    for i in 1..tasks {
                        // Lock queue mutex
                        let mut queue = queue.lock().await;

                        // Check if the queue is non-full
                        while queue.is_full() {
                            // Release queue mutex and sleep
                            queue = condvar.wait(queue).await;
                        }

                        queue.items.push_back(format!("producer1: task {i}"));

                        // Wake up all producer and consumer tasks
                        condvar.broadcast();
                    }
                }
            });

            let _producer2 = spawn({
                let queue = queue.clone();
                let condvar = condvar.clone();
                async move {
                    for i in 1..tasks {
                        // Lock queue mutex
                        let mut queue = queue.lock().await;

                        // Check if the queue is non-full
                        while queue.is_full() {
                            // Release queue mutex and sleep
                            queue = condvar.wait(queue).await;
                        }

                        queue.items.push_back(format!("producer2: task {i}"));

                        // Wake up all producer and consumer tasks
                        condvar.broadcast();
                    }
                }
            });

            let task_consumed = Arc::new(AtomicUsize::new(0));

            let consumer = spawn({
                let queue = queue.clone();
                let task_consumed = task_consumed.clone();
                async move {
                    for _ in 0..expected_items {
                        // Lock queue mutex
                        let mut queue = queue.lock().await;

                        // Check if the queue is non-empty
                        while queue.is_empty() {
                            // Release queue mutex and sleep
                            queue = condvar.wait(queue).await;
                        }

                        let _ = queue.items.pop_front().unwrap();

                        task_consumed.fetch_add(1, Ordering::Relaxed);

                        // Do something

                        // Wake up all producer and consumer tasks
                        condvar.broadcast();
                    }
                }
            });

            let _ = consumer.await;
            assert!(queue.lock().await.is_empty());
            assert_eq!(task_consumed.load(Ordering::Relaxed), expected_items);
        });
    }
}