karyon_core/async_util/select.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use pin_project_lite::pin_project;
6
7/// Returns the result of the future that completes first, preferring future1
8/// if both are ready.
9///
10/// # Examples
11///
12/// ```
13/// use std::future;
14///
15/// use karyon_core::async_util::{select, Either};
16///
17///  async {
18///     let fut1 = future::pending::<String>();
19///     let fut2 = future::ready(0);
20///     let res = select(fut1, fut2).await;
21///     assert!(matches!(res, Either::Right(0)));
22///     // ....
23///  };
24///
25/// ```
26///
27pub fn select<T1, T2, F1, F2>(future1: F1, future2: F2) -> Select<F1, F2>
28where
29    F1: Future<Output = T1>,
30    F2: Future<Output = T2>,
31{
32    Select { future1, future2 }
33}
34
35pin_project! {
36    #[derive(Debug)]
37    pub struct Select<F1, F2> {
38        #[pin]
39        future1: F1,
40        #[pin]
41        future2: F2,
42    }
43}
44
/// The return value from the [`select`] function, indicating which future
/// completed first.
///
/// Comparison, hashing and copy/clone traits are derived, so they are
/// available whenever both payload types implement them (e.g. results can be
/// checked with `assert_eq!`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Either<T1, T2> {
    /// Output of the first future passed to [`select`]. Also chosen when
    /// both futures are ready, because the first future is polled first.
    Left(T1),
    /// Output of the second future passed to [`select`]; only produced when
    /// the first future was still pending.
    Right(T2),
}
52
53// Implement the Future trait for the Select struct.
54impl<T1, T2, F1, F2> Future for Select<F1, F2>
55where
56    F1: Future<Output = T1>,
57    F2: Future<Output = T2>,
58{
59    type Output = Either<T1, T2>;
60
61    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
62        let this = self.project();
63
64        if let Poll::Ready(t) = this.future1.poll(cx) {
65            return Poll::Ready(Either::Left(t));
66        }
67
68        if let Poll::Ready(t) = this.future2.poll(cx) {
69            return Poll::Ready(Either::Right(t));
70        }
71
72        Poll::Pending
73    }
74}
75
#[cfg(test)]
mod tests {
    use std::{future, time::Duration};

    use crate::{async_runtime::block_on, async_util::sleep};

    use super::{select, Either};

    #[test]
    fn test_async_select() {
        block_on(async move {
            // A maximal-duration sleep loses to an already-ready future.
            let fut = select(sleep(Duration::MAX), future::ready(0u32)).await;
            assert!(matches!(fut, Either::Right(0)));

            let fut1 = future::pending::<String>();
            let fut2 = future::ready(0);
            let res = select(fut1, fut2).await;
            assert!(matches!(res, Either::Right(0)));

            let fut1 = future::ready(0);
            let fut2 = future::pending::<String>();
            let res = select(fut1, fut2).await;
            assert!(matches!(res, Either::Left(_)));
        });
    }

    #[test]
    fn test_async_select_prefers_left_when_both_ready() {
        block_on(async move {
            // `select` documents that `future1` wins when both are ready.
            let res = select(future::ready(1u8), future::ready("right")).await;
            assert!(matches!(res, Either::Left(1)));
        });
    }
}
101}