// karyon_core/async_util/task_group.rs

1use std::{future::Future, sync::Arc};
2
3use parking_lot::Mutex;
4
5use crate::async_runtime::{global_executor, Executor, Task};
6
7use super::{select, CondWait, Either};
8
9/// TaskGroup A group that contains spawned tasks.
10///
11/// # Example
12///
13/// ```
14///
15/// use std::sync::Arc;
16///
17/// use karyon_core::async_util::{TaskGroup, sleep};
18///
19/// async {
20///     let group = TaskGroup::new();
21///
22///     group.spawn(sleep(std::time::Duration::MAX), |_| async {});
23///
24///     group.cancel().await;
25///
26/// };
27///
28/// ```
29pub struct TaskGroup {
30    tasks: Mutex<Vec<TaskHandler>>,
31    stop_signal: Arc<CondWait>,
32    executor: Executor,
33}
34
35impl TaskGroup {
36    /// Creates a new TaskGroup without providing an executor
37    ///
38    /// This will spawn a task onto a global executor (single-threaded by default).
39    pub fn new() -> Self {
40        Self {
41            tasks: Mutex::new(Vec::new()),
42            stop_signal: Arc::new(CondWait::new()),
43            executor: global_executor(),
44        }
45    }
46
47    /// Creates a new TaskGroup by providing an executor
48    pub fn with_executor(executor: Executor) -> Self {
49        Self {
50            tasks: Mutex::new(Vec::new()),
51            stop_signal: Arc::new(CondWait::new()),
52            executor,
53        }
54    }
55
56    /// Spawns a new task and calls the callback after it has completed
57    /// or been canceled. The callback will have the `TaskResult` as a
58    /// parameter, indicating whether the task completed or was canceled.
59    pub fn spawn<T, Fut, CallbackF, CallbackFut>(&self, fut: Fut, callback: CallbackF)
60    where
61        T: Send + Sync + 'static,
62        Fut: Future<Output = T> + Send + 'static,
63        CallbackF: FnOnce(TaskResult<T>) -> CallbackFut + Send + 'static,
64        CallbackFut: Future<Output = ()> + Send + 'static,
65    {
66        let task = TaskHandler::new(
67            self.executor.clone(),
68            fut,
69            callback,
70            self.stop_signal.clone(),
71        );
72        self.tasks.lock().push(task);
73    }
74
75    /// Checks if the TaskGroup is empty.
76    pub fn is_empty(&self) -> bool {
77        self.tasks.lock().is_empty()
78    }
79
80    /// Get the number of the tasks in the group.
81    pub fn len(&self) -> usize {
82        self.tasks.lock().len()
83    }
84
85    /// Cancels all tasks in the group.
86    pub async fn cancel(&self) {
87        self.stop_signal.broadcast().await;
88
89        loop {
90            // XXX BE CAREFUL HERE, it hold synchronous mutex across .await point.
91            let task = self.tasks.lock().pop();
92            if let Some(t) = task {
93                t.cancel().await
94            } else {
95                break;
96            }
97        }
98    }
99}
100
101impl Default for TaskGroup {
102    fn default() -> Self {
103        Self::new()
104    }
105}
106
/// Outcome handed to a spawned task's callback once the task stops.
#[derive(Debug)]
pub enum TaskResult<T> {
    /// The task's future ran to completion and produced this value.
    Completed(T),
    /// The group's stop signal fired before the future finished.
    Cancelled,
}
113
114impl<T: std::fmt::Debug> std::fmt::Display for TaskResult<T> {
115    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
116        match self {
117            TaskResult::Cancelled => write!(f, "Task cancelled"),
118            TaskResult::Completed(res) => write!(f, "Task completed: {:?}", res),
119        }
120    }
121}
122
123/// TaskHandler
124pub struct TaskHandler {
125    task: Task<()>,
126    cancel_flag: Arc<CondWait>,
127}
128
129impl TaskHandler {
130    /// Creates a new task handler
131    fn new<T, Fut, CallbackF, CallbackFut>(
132        ex: Executor,
133        fut: Fut,
134        callback: CallbackF,
135        stop_signal: Arc<CondWait>,
136    ) -> TaskHandler
137    where
138        T: Send + Sync + 'static,
139        Fut: Future<Output = T> + Send + 'static,
140        CallbackF: FnOnce(TaskResult<T>) -> CallbackFut + Send + 'static,
141        CallbackFut: Future<Output = ()> + Send + 'static,
142    {
143        let cancel_flag = Arc::new(CondWait::new());
144        let cancel_flag_c = cancel_flag.clone();
145        let task = ex.spawn(async move {
146            // Waits for either the stop signal or the task to complete.
147            let result = select(stop_signal.wait(), fut).await;
148
149            let result = match result {
150                Either::Left(_) => TaskResult::Cancelled,
151                Either::Right(res) => TaskResult::Completed(res),
152            };
153
154            // Call the callback
155            callback(result).await;
156
157            cancel_flag_c.signal().await;
158        });
159
160        TaskHandler { task, cancel_flag }
161    }
162
163    /// Cancels the task.
164    async fn cancel(self) {
165        self.cancel_flag.wait().await;
166        self.task.cancel().await;
167    }
168}
169
#[cfg(test)]
mod tests {
    use std::{future, sync::Arc};

    use crate::async_runtime::block_on;
    use crate::async_util::sleep;

    use super::*;

    /// Spawns into `group` the scenarios that every executor test checks:
    ///
    /// 1. a task that completes immediately with `0`;
    /// 2. a task that never finishes and so must end up cancelled;
    /// 3. a task that completes after spawning another never-ending task into
    ///    the same group, which must also end up cancelled.
    fn spawn_scenarios(group: &Arc<TaskGroup>) {
        group.spawn(future::ready(0), |res| async move {
            assert!(matches!(res, TaskResult::Completed(0)));
        });

        group.spawn(future::pending::<()>(), |res| async move {
            assert!(matches!(res, TaskResult::Cancelled));
        });

        let inner_group = Arc::clone(group);
        group.spawn(
            async move {
                inner_group.spawn(future::pending::<()>(), |res| async move {
                    assert!(matches!(res, TaskResult::Cancelled));
                });
            },
            |res| async move {
                assert!(matches!(res, TaskResult::Completed(_)));
            },
        );
    }

    #[cfg(feature = "tokio")]
    #[test]
    fn test_task_group_with_tokio_executor() {
        let runtime = Arc::new(tokio::runtime::Runtime::new().unwrap());
        runtime.clone().block_on(async move {
            let group = Arc::new(TaskGroup::with_executor(runtime.into()));
            spawn_scenarios(&group);

            // Give the spawned tasks a moment to run before cancelling.
            tokio::time::sleep(std::time::Duration::from_millis(50)).await;
            group.cancel().await;
        });
    }

    #[cfg(feature = "smol")]
    #[test]
    fn test_task_group_with_smol_executor() {
        let executor = Arc::new(smol::Executor::new());
        smol::block_on(executor.clone().run(async move {
            let group = Arc::new(TaskGroup::with_executor(executor.into()));
            spawn_scenarios(&group);

            // Give the spawned tasks a moment to run before cancelling.
            smol::Timer::after(std::time::Duration::from_millis(50)).await;
            group.cancel().await;
        }));
    }

    #[test]
    fn test_task_group() {
        block_on(async {
            let group = Arc::new(TaskGroup::new());
            spawn_scenarios(&group);

            // Give the spawned tasks a moment to run before cancelling.
            sleep(std::time::Duration::from_millis(50)).await;
            group.cancel().await;
        });
    }
}