karyon_core/async_util/
condvar.rs1use std::{
2 collections::HashMap,
3 future::Future,
4 pin::Pin,
5 task::{Context, Poll, Waker},
6};
7
8use parking_lot::Mutex;
9
10use crate::{async_runtime::lock::MutexGuard, util::random_16};
11
12pub struct CondVar {
62 inner: Mutex<Wakers>,
63}
64
65impl CondVar {
66 pub fn new() -> Self {
68 Self {
69 inner: Mutex::new(Wakers::new()),
70 }
71 }
72
73 pub async fn wait<'a, T>(&self, g: MutexGuard<'a, T>) -> MutexGuard<'a, T> {
75 #[cfg(feature = "smol")]
76 let m = MutexGuard::source(&g);
77 #[cfg(feature = "tokio")]
78 let m = MutexGuard::mutex(&g);
79
80 CondVarAwait::new(self, g).await;
81
82 m.lock().await
83 }
84
85 pub fn signal(&self) {
87 self.inner.lock().wake(true);
88 }
89
90 pub fn broadcast(&self) {
92 self.inner.lock().wake(false);
93 }
94}
95
96impl Default for CondVar {
97 fn default() -> Self {
98 Self::new()
99 }
100}
101
102struct CondVarAwait<'a, T> {
103 id: Option<u16>,
104 condvar: &'a CondVar,
105 guard: Option<MutexGuard<'a, T>>,
106}
107
108impl<'a, T> CondVarAwait<'a, T> {
109 fn new(condvar: &'a CondVar, guard: MutexGuard<'a, T>) -> Self {
110 Self {
111 condvar,
112 guard: Some(guard),
113 id: None,
114 }
115 }
116}
117
118impl<T> Future for CondVarAwait<'_, T> {
119 type Output = ();
120
121 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
122 let mut inner = self.condvar.inner.lock();
123
124 match self.guard.take() {
125 Some(_) => {
126 self.id = Some(inner.put(Some(cx.waker().clone())));
128 Poll::Pending
129 }
130 None => {
131 if self.id.is_none() {
134 return Poll::Ready(());
135 }
136
137 let i = self.id.as_ref().unwrap();
138 match inner.wakers.get_mut(i).unwrap() {
139 Some(wk) => {
140 if !wk.will_wake(cx.waker()) {
142 wk.clone_from(cx.waker());
143 }
144 Poll::Pending
145 }
146 None => {
147 inner.delete(i);
148 self.id = None;
149 Poll::Ready(())
150 }
151 }
152 }
153 }
154 }
155}
156
157impl<T> Drop for CondVarAwait<'_, T> {
158 fn drop(&mut self) {
159 if let Some(id) = self.id {
160 let mut inner = self.condvar.inner.lock();
161 if let Some(wk) = inner.wakers.get_mut(&id).unwrap().take() {
162 wk.wake()
163 }
164 }
165 }
166}
167
168struct Wakers {
170 wakers: HashMap<u16, Option<Waker>>,
171}
172
173impl Wakers {
174 fn new() -> Self {
175 Self {
176 wakers: HashMap::new(),
177 }
178 }
179
180 fn put(&mut self, waker: Option<Waker>) -> u16 {
181 let mut id: u16;
182
183 id = random_16();
184 while self.wakers.contains_key(&id) {
185 id = random_16();
186 }
187
188 self.wakers.insert(id, waker);
189 id
190 }
191
192 fn delete(&mut self, id: &u16) -> Option<Option<Waker>> {
193 self.wakers.remove(id)
194 }
195
196 fn wake(&mut self, signal: bool) {
197 for (_, wk) in self.wakers.iter_mut() {
198 match wk.take() {
199 Some(w) => {
200 w.wake();
201 if signal {
202 break;
203 }
204 }
205 None => continue,
206 }
207 }
208 }
209}
210
#[cfg(test)]
mod tests {
    use std::{
        collections::VecDeque,
        sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        },
    };

    use crate::async_runtime::{block_on, lock::Mutex, spawn};

    use super::*;

    // A bounded FIFO of strings shared between producer and consumer tasks.
    struct Queue {
        items: VecDeque<String>,
        max_len: usize,
    }
    impl Queue {
        // Creates an empty queue that is considered full at `max_len` items.
        fn new(max_len: usize) -> Self {
            Self {
                items: VecDeque::new(),
                max_len,
            }
        }

        // True when the queue holds exactly `max_len` items.
        fn is_full(&self) -> bool {
            self.items.len() == self.max_len
        }

        // True when the queue holds no items.
        fn is_empty(&self) -> bool {
            self.items.is_empty()
        }
    }

    // One producer and one consumer share a queue of capacity 5, using two
    // condition variables and `signal`: `condvar_full` is waited on while the
    // queue is full, `condvar_empty` while it is empty.
    #[test]
    fn test_condvar_signal() {
        block_on(async {
            // The loops below run over `1..number_of_tasks`, i.e. 29 iterations.
            let number_of_tasks = 30;

            let queue = Arc::new(Mutex::new(Queue::new(5)));
            let condvar_full = Arc::new(CondVar::new());
            let condvar_empty = Arc::new(CondVar::new());

            let _producer1 = spawn({
                let queue = queue.clone();
                let condvar_full = condvar_full.clone();
                let condvar_empty = condvar_empty.clone();
                async move {
                    for i in 1..number_of_tasks {
                        // Lock the queue
                        let mut queue = queue.lock().await;

                        // Wait for room; the condition is re-checked after
                        // every wake-up.
                        while queue.is_full() {
                            // `wait` consumes the guard and returns a new one
                            queue = condvar_full.wait(queue).await;
                        }

                        queue.items.push_back(format!("task {i}"));

                        // Let the consumer know an item is available
                        condvar_empty.signal();
                    }
                }
            });

            let task_consumed = Arc::new(AtomicUsize::new(0));

            // The consumer takes ownership of the original `condvar_full` and
            // `condvar_empty` handles (they are moved, not cloned).
            let consumer = spawn({
                let queue = queue.clone();
                let task_consumed = task_consumed.clone();
                async move {
                    for _ in 1..number_of_tasks {
                        // Lock the queue
                        let mut queue = queue.lock().await;

                        // Wait for an item; the condition is re-checked after
                        // every wake-up.
                        while queue.is_empty() {
                            // `wait` consumes the guard and returns a new one
                            queue = condvar_empty.wait(queue).await;
                        }

                        let _ = queue.items.pop_front().unwrap();

                        task_consumed.fetch_add(1, Ordering::Relaxed);

                        // Let the producer know there is room again
                        condvar_full.signal();
                    }
                }
            });

            // Once the consumer has finished, all 29 produced items must have
            // been consumed.
            let _ = consumer.await;
            assert!(queue.lock().await.is_empty());
            assert_eq!(task_consumed.load(Ordering::Relaxed), 29);
        });
    }

    // Two producers and one consumer share a queue of capacity 5 and a single
    // condition variable, using `broadcast` for every notification.
    #[test]
    fn test_condvar_broadcast() {
        block_on(async {
            // Each producer loops over `1..tasks`, i.e. 29 items per producer.
            let tasks = 30;

            let queue = Arc::new(Mutex::new(Queue::new(5)));
            let condvar = Arc::new(CondVar::new());

            let _producer1 = spawn({
                let queue = queue.clone();
                let condvar = condvar.clone();
                async move {
                    for i in 1..tasks {
                        // Lock the queue
                        let mut queue = queue.lock().await;

                        // Wait for room; re-checked after every wake-up
                        while queue.is_full() {
                            // `wait` consumes the guard and returns a new one
                            queue = condvar.wait(queue).await;
                        }

                        queue.items.push_back(format!("producer1: task {i}"));

                        // Notify all waiting tasks
                        condvar.broadcast();
                    }
                }
            });

            let _producer2 = spawn({
                let queue = queue.clone();
                let condvar = condvar.clone();
                async move {
                    for i in 1..tasks {
                        // Lock the queue
                        let mut queue = queue.lock().await;

                        // Wait for room; re-checked after every wake-up
                        while queue.is_full() {
                            // `wait` consumes the guard and returns a new one
                            queue = condvar.wait(queue).await;
                        }

                        queue.items.push_back(format!("producer2: task {i}"));

                        // Notify all waiting tasks
                        condvar.broadcast();
                    }
                }
            });

            let task_consumed = Arc::new(AtomicUsize::new(0));

            // The consumer takes ownership of the original `condvar` handle.
            let consumer = spawn({
                let queue = queue.clone();
                let task_consumed = task_consumed.clone();
                async move {
                    // `1..59`: 58 iterations, one per produced item (2 * 29).
                    for _ in 1..((tasks * 2) - 1) {
                        {
                            // Lock the queue
                            let mut queue = queue.lock().await;

                            // Wait for an item; re-checked after every wake-up
                            while queue.is_empty() {
                                // `wait` consumes the guard and returns a new one
                                queue = condvar.wait(queue).await;
                            }

                            let _ = queue.items.pop_front().unwrap();

                            task_consumed.fetch_add(1, Ordering::Relaxed);

                            // Notify all waiting tasks
                            condvar.broadcast();
                        }
                    }
                }
            });

            // Once the consumer has finished, all 58 produced items must have
            // been consumed.
            let _ = consumer.await;
            assert!(queue.lock().await.is_empty());
            assert_eq!(task_consumed.load(Ordering::Relaxed), 58);
        });
    }
}