karyon_core/async_util/
task_group.rs1use std::{future::Future, sync::Arc};
2
3use parking_lot::Mutex;
4
5use crate::async_runtime::{global_executor, Executor, Task};
6
7use super::{select, CondWait, Either};
8
/// A set of tasks spawned on one executor that can be cancelled together.
///
/// Every task spawned through [`TaskGroup::spawn`] races its future against a
/// stop signal shared by the whole group; [`TaskGroup::cancel`] broadcasts that
/// signal and then waits for each task's callback to finish.
pub struct TaskGroup {
    /// Handles of every task spawned through `spawn`; drained only by `cancel`.
    tasks: Mutex<Vec<TaskHandler>>,
    /// Shared with each spawned task and broadcast by `cancel` to stop them.
    stop_signal: Arc<CondWait>,
    /// Executor that new tasks are spawned on.
    executor: Executor,
}
34
35impl TaskGroup {
36 pub fn new() -> Self {
40 Self {
41 tasks: Mutex::new(Vec::new()),
42 stop_signal: Arc::new(CondWait::new()),
43 executor: global_executor(),
44 }
45 }
46
47 pub fn with_executor(executor: Executor) -> Self {
49 Self {
50 tasks: Mutex::new(Vec::new()),
51 stop_signal: Arc::new(CondWait::new()),
52 executor,
53 }
54 }
55
56 pub fn spawn<T, Fut, CallbackF, CallbackFut>(&self, fut: Fut, callback: CallbackF)
60 where
61 T: Send + Sync + 'static,
62 Fut: Future<Output = T> + Send + 'static,
63 CallbackF: FnOnce(TaskResult<T>) -> CallbackFut + Send + 'static,
64 CallbackFut: Future<Output = ()> + Send + 'static,
65 {
66 let task = TaskHandler::new(
67 self.executor.clone(),
68 fut,
69 callback,
70 self.stop_signal.clone(),
71 );
72 self.tasks.lock().push(task);
73 }
74
75 pub fn is_empty(&self) -> bool {
77 self.tasks.lock().is_empty()
78 }
79
80 pub fn len(&self) -> usize {
82 self.tasks.lock().len()
83 }
84
85 pub async fn cancel(&self) {
87 self.stop_signal.broadcast().await;
88
89 loop {
90 let task = self.tasks.lock().pop();
92 if let Some(t) = task {
93 t.cancel().await
94 } else {
95 break;
96 }
97 }
98 }
99}
100
101impl Default for TaskGroup {
102 fn default() -> Self {
103 Self::new()
104 }
105}
106
/// Outcome handed to a [`TaskGroup::spawn`] callback.
#[derive(Debug)]
pub enum TaskResult<T> {
    /// The spawned future ran to completion and produced this value.
    Completed(T),
    /// The group's stop signal won the race before the future finished.
    Cancelled,
}
113
114impl<T: std::fmt::Debug> std::fmt::Display for TaskResult<T> {
115 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
116 match self {
117 TaskResult::Cancelled => write!(f, "Task cancelled"),
118 TaskResult::Completed(res) => write!(f, "Task completed: {:?}", res),
119 }
120 }
121}
122
/// Handle to a single task spawned by [`TaskGroup::spawn`].
pub struct TaskHandler {
    /// The runtime task running the raced future followed by its callback.
    task: Task<()>,
    /// Signalled by the spawned body once its callback has finished.
    cancel_flag: Arc<CondWait>,
}
129impl TaskHandler {
130 fn new<T, Fut, CallbackF, CallbackFut>(
132 ex: Executor,
133 fut: Fut,
134 callback: CallbackF,
135 stop_signal: Arc<CondWait>,
136 ) -> TaskHandler
137 where
138 T: Send + Sync + 'static,
139 Fut: Future<Output = T> + Send + 'static,
140 CallbackF: FnOnce(TaskResult<T>) -> CallbackFut + Send + 'static,
141 CallbackFut: Future<Output = ()> + Send + 'static,
142 {
143 let cancel_flag = Arc::new(CondWait::new());
144 let cancel_flag_c = cancel_flag.clone();
145 let task = ex.spawn(async move {
146 let result = select(stop_signal.wait(), fut).await;
148
149 let result = match result {
150 Either::Left(_) => TaskResult::Cancelled,
151 Either::Right(res) => TaskResult::Completed(res),
152 };
153
154 callback(result).await;
156
157 cancel_flag_c.signal().await;
158 });
159
160 TaskHandler { task, cancel_flag }
161 }
162
163 async fn cancel(self) {
165 self.cancel_flag.wait().await;
166 self.task.cancel().await;
167 }
168}
169
#[cfg(test)]
mod tests {
    use std::{
        future,
        sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        },
    };

    use crate::async_runtime::block_on;
    use crate::async_util::sleep;

    use super::*;

    /// Number of callbacks registered by `spawn_fixture`, including the nested one.
    const FIXTURE_TASKS: usize = 4;

    /// Spawns the shared fixture into `group`. Returns a counter that each callback
    /// increments when its `TaskResult` matches what the test expects.
    ///
    /// The fixture contains:
    /// - a future that completes immediately with `0`;
    /// - a future that never completes, so it must be cancelled;
    /// - a future that completes after spawning one more never-completing task into
    ///   the same group.
    ///
    /// Callbacks record outcomes instead of asserting. Depending on the runtime, a
    /// panic inside a spawned callback may never reach the test thread. Worse, the
    /// panicking task would never signal its completion flag, so `group.cancel()`
    /// would wait forever instead of failing the test.
    fn spawn_fixture(group: &Arc<TaskGroup>) -> Arc<AtomicUsize> {
        let matched = Arc::new(AtomicUsize::new(0));

        let m = matched.clone();
        group.spawn(future::ready(0), move |res| async move {
            if matches!(res, TaskResult::Completed(0)) {
                m.fetch_add(1, Ordering::SeqCst);
            }
        });

        let m = matched.clone();
        group.spawn(future::pending::<()>(), move |res| async move {
            if matches!(res, TaskResult::Cancelled) {
                m.fetch_add(1, Ordering::SeqCst);
            }
        });

        let groupc = group.clone();
        let m_nested = matched.clone();
        let m = matched.clone();
        group.spawn(
            async move {
                groupc.spawn(future::pending::<()>(), move |res| async move {
                    if matches!(res, TaskResult::Cancelled) {
                        m_nested.fetch_add(1, Ordering::SeqCst);
                    }
                });
            },
            move |res| async move {
                if matches!(res, TaskResult::Completed(_)) {
                    m.fetch_add(1, Ordering::SeqCst);
                }
            },
        );

        matched
    }

    #[cfg(feature = "tokio")]
    #[test]
    fn test_task_group_with_tokio_executor() {
        let ex = Arc::new(tokio::runtime::Runtime::new().unwrap());
        ex.clone().block_on(async move {
            let group = Arc::new(TaskGroup::with_executor(ex.into()));
            let matched = spawn_fixture(&group);

            // Give the tasks time to run, so the nested `spawn` has registered its
            // handle before `cancel` drains the group.
            tokio::time::sleep(std::time::Duration::from_millis(50)).await;
            group.cancel().await;

            // `cancel` waits for every drained task's callback, so all have run by now.
            assert_eq!(matched.load(Ordering::SeqCst), FIXTURE_TASKS);
        });
    }

    #[cfg(feature = "smol")]
    #[test]
    fn test_task_group_with_smol_executor() {
        let ex = Arc::new(smol::Executor::new());
        smol::block_on(ex.clone().run(async move {
            let group = Arc::new(TaskGroup::with_executor(ex.into()));
            let matched = spawn_fixture(&group);

            // Give the tasks time to run, so the nested `spawn` has registered its
            // handle before `cancel` drains the group.
            smol::Timer::after(std::time::Duration::from_millis(50)).await;
            group.cancel().await;

            // `cancel` waits for every drained task's callback, so all have run by now.
            assert_eq!(matched.load(Ordering::SeqCst), FIXTURE_TASKS);
        }));
    }

    #[test]
    fn test_task_group() {
        block_on(async {
            let group = Arc::new(TaskGroup::new());
            let matched = spawn_fixture(&group);

            // Give the tasks time to run, so the nested `spawn` has registered its
            // handle before `cancel` drains the group.
            sleep(std::time::Duration::from_millis(50)).await;
            group.cancel().await;

            // `cancel` waits for every drained task's callback, so all have run by now.
            assert_eq!(matched.load(Ordering::SeqCst), FIXTURE_TASKS);
        });
    }
}