karyon_core/async_util/
condvar.rs1use std::{
2 collections::{hash_map::Entry, HashMap},
3 future::Future,
4 pin::Pin,
5 task::{Context, Poll, Waker},
6};
7
8use parking_lot::Mutex;
9
10use crate::{async_runtime::lock::MutexGuard, util::random_16};
11
12pub struct CondVar {
62 inner: Mutex<Wakers>,
63}
64
65impl CondVar {
66 pub fn new() -> Self {
68 Self {
69 inner: Mutex::new(Wakers::new()),
70 }
71 }
72
73 pub async fn wait<'a, T>(&self, g: MutexGuard<'a, T>) -> MutexGuard<'a, T> {
75 #[cfg(feature = "smol")]
76 let m = MutexGuard::source(&g);
77 #[cfg(feature = "tokio")]
78 let m = MutexGuard::mutex(&g);
79
80 CondVarAwait::new(self, g).await;
81
82 m.lock().await
83 }
84
85 pub fn signal(&self) {
87 self.inner.lock().wake(true);
88 }
89
90 pub fn broadcast(&self) {
92 self.inner.lock().wake(false);
93 }
94}
95
96impl Default for CondVar {
97 fn default() -> Self {
98 Self::new()
99 }
100}
101
102struct CondVarAwait<'a, T> {
103 id: Option<u16>,
104 condvar: &'a CondVar,
105 guard: Option<MutexGuard<'a, T>>,
106}
107
108impl<'a, T> CondVarAwait<'a, T> {
109 fn new(condvar: &'a CondVar, guard: MutexGuard<'a, T>) -> Self {
110 Self {
111 condvar,
112 guard: Some(guard),
113 id: None,
114 }
115 }
116}
117
118impl<T> Future for CondVarAwait<'_, T> {
119 type Output = ();
120
121 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
122 let mut inner = self.condvar.inner.lock();
123
124 match self.guard.take() {
125 Some(_) => {
126 self.id = Some(inner.put(Some(cx.waker().clone())));
128 Poll::Pending
129 }
130 None => {
131 if self.id.is_none() {
134 return Poll::Ready(());
135 }
136
137 let i = self.id.as_ref().unwrap();
138 let waker_op = match inner.wakers.get_mut(i) {
139 Some(wk) => wk,
140 None => {
141 self.id = None;
142 return Poll::Ready(());
143 }
144 };
145
146 match waker_op {
147 Some(wk) => {
148 if !wk.will_wake(cx.waker()) {
150 wk.clone_from(cx.waker());
151 }
152 Poll::Pending
153 }
154 None => {
155 inner.delete(i);
156 self.id = None;
157 Poll::Ready(())
158 }
159 }
160 }
161 }
162 }
163}
164
165impl<T> Drop for CondVarAwait<'_, T> {
166 fn drop(&mut self) {
167 if let Some(id) = self.id {
168 let mut inner = self.condvar.inner.lock();
169 if let Some(wk) = inner.wakers.remove(&id).flatten() {
170 wk.wake()
171 }
172 }
173 }
174}
175
176struct Wakers {
178 wakers: HashMap<u16, Option<Waker>>,
179}
180
181impl Wakers {
182 fn new() -> Self {
183 Self {
184 wakers: HashMap::new(),
185 }
186 }
187
188 fn put(&mut self, waker: Option<Waker>) -> u16 {
189 const MAX_RETRIES: u8 = 100;
190 let mut id: u16;
191
192 for _ in 0..MAX_RETRIES {
193 id = random_16();
194 if let Entry::Vacant(e) = self.wakers.entry(id) {
195 e.insert(waker);
196 return id;
197 }
198 }
199
200 panic!("Wakers: All IDs exhausted");
201 }
202
203 fn delete(&mut self, id: &u16) -> Option<Option<Waker>> {
204 self.wakers.remove(id)
205 }
206
207 fn wake(&mut self, signal: bool) {
208 for (_, wk) in self.wakers.iter_mut() {
209 match wk.take() {
210 Some(w) => {
211 w.wake();
212 if signal {
213 break;
214 }
215 }
216 None => continue,
217 }
218 }
219 }
220}
221
#[cfg(test)]
mod tests {
    use std::{
        collections::VecDeque,
        sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        },
    };

    use crate::async_runtime::{block_on, lock::Mutex, spawn};

    use super::*;

    /// Bounded FIFO shared between the producer and consumer tasks below.
    struct Queue {
        items: VecDeque<String>,
        max_len: usize,
    }

    impl Queue {
        /// Creates an empty queue that reports full at `max_len` items.
        fn new(max_len: usize) -> Self {
            let items = VecDeque::new();
            Queue { items, max_len }
        }

        fn is_full(&self) -> bool {
            self.max_len == self.items.len()
        }

        fn is_empty(&self) -> bool {
            self.items.is_empty()
        }
    }

    /// One producer and one consumer coordinated through `signal` on two
    /// condition variables, one per predicate.
    #[test]
    fn test_condvar_signal() {
        block_on(async {
            let number_of_tasks = 30;

            let queue = Arc::new(Mutex::new(Queue::new(5)));
            // The producer parks here while the queue is full.
            let not_full = Arc::new(CondVar::new());
            // The consumer parks here while the queue is empty.
            let not_empty = Arc::new(CondVar::new());

            let _producer1 = spawn({
                let shared = Arc::clone(&queue);
                let not_full = Arc::clone(&not_full);
                let not_empty = Arc::clone(&not_empty);
                async move {
                    for i in 1..number_of_tasks {
                        let mut guard = shared.lock().await;
                        while guard.is_full() {
                            guard = not_full.wait(guard).await;
                        }
                        guard.items.push_back(format!("task {i}"));
                        not_empty.signal();
                    }
                }
            });

            let consumed = Arc::new(AtomicUsize::new(0));

            let consumer = spawn({
                let shared = Arc::clone(&queue);
                let consumed = Arc::clone(&consumed);
                async move {
                    for _ in 1..number_of_tasks {
                        let mut guard = shared.lock().await;
                        while guard.is_empty() {
                            guard = not_empty.wait(guard).await;
                        }
                        let _ = guard.items.pop_front().unwrap();
                        consumed.fetch_add(1, Ordering::Relaxed);
                        not_full.signal();
                    }
                }
            });

            let _ = consumer.await;
            assert!(queue.lock().await.is_empty());
            assert_eq!(consumed.load(Ordering::Relaxed), 29);
        });
    }

    /// Two producers and one consumer sharing a single condition variable and
    /// waking each other with `broadcast`.
    #[test]
    fn test_condvar_broadcast() {
        block_on(async {
            let tasks = 30;

            let queue = Arc::new(Mutex::new(Queue::new(5)));
            let cv = Arc::new(CondVar::new());

            let _producer1 = spawn({
                let shared = Arc::clone(&queue);
                let cv = Arc::clone(&cv);
                async move {
                    for i in 1..tasks {
                        let mut guard = shared.lock().await;
                        while guard.is_full() {
                            guard = cv.wait(guard).await;
                        }
                        guard.items.push_back(format!("producer1: task {i}"));
                        cv.broadcast();
                    }
                }
            });

            let _producer2 = spawn({
                let shared = Arc::clone(&queue);
                let cv = Arc::clone(&cv);
                async move {
                    for i in 1..tasks {
                        let mut guard = shared.lock().await;
                        while guard.is_full() {
                            guard = cv.wait(guard).await;
                        }
                        guard.items.push_back(format!("producer2: task {i}"));
                        cv.broadcast();
                    }
                }
            });

            let consumed = Arc::new(AtomicUsize::new(0));

            let consumer = spawn({
                let shared = Arc::clone(&queue);
                let consumed = Arc::clone(&consumed);
                async move {
                    // Both producers push `tasks - 1` items each.
                    for _ in 1..(tasks * 2 - 1) {
                        let mut guard = shared.lock().await;
                        while guard.is_empty() {
                            guard = cv.wait(guard).await;
                        }
                        let _ = guard.items.pop_front().unwrap();
                        consumed.fetch_add(1, Ordering::Relaxed);
                        cv.broadcast();
                    }
                }
            });

            let _ = consumer.await;
            assert!(queue.lock().await.is_empty());
            assert_eq!(consumed.load(Ordering::Relaxed), 58);
        });
    }
}