dara-core 1.20.1a1__py3-none-any.whl → 1.20.1a3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dara/core/__init__.py +3 -0
- dara/core/actions.py +1 -2
- dara/core/auth/basic.py +22 -16
- dara/core/auth/definitions.py +2 -2
- dara/core/auth/routes.py +5 -5
- dara/core/auth/utils.py +5 -5
- dara/core/base_definitions.py +22 -64
- dara/core/cli.py +8 -7
- dara/core/configuration.py +5 -2
- dara/core/css.py +1 -2
- dara/core/data_utils.py +18 -19
- dara/core/defaults.py +6 -7
- dara/core/definitions.py +50 -19
- dara/core/http.py +7 -3
- dara/core/interactivity/__init__.py +6 -0
- dara/core/interactivity/actions.py +52 -50
- dara/core/interactivity/any_data_variable.py +7 -134
- dara/core/interactivity/any_variable.py +5 -8
- dara/core/interactivity/client_variable.py +71 -0
- dara/core/interactivity/data_variable.py +8 -266
- dara/core/interactivity/derived_data_variable.py +7 -290
- dara/core/interactivity/derived_variable.py +416 -176
- dara/core/interactivity/filtering.py +46 -27
- dara/core/interactivity/loop_variable.py +2 -2
- dara/core/interactivity/non_data_variable.py +5 -68
- dara/core/interactivity/plain_variable.py +89 -15
- dara/core/interactivity/server_variable.py +325 -0
- dara/core/interactivity/state_variable.py +69 -0
- dara/core/interactivity/switch_variable.py +19 -19
- dara/core/interactivity/tabular_variable.py +94 -0
- dara/core/interactivity/url_variable.py +10 -90
- dara/core/internal/cache_store/base_impl.py +2 -1
- dara/core/internal/cache_store/cache_store.py +22 -25
- dara/core/internal/cache_store/keep_all.py +4 -1
- dara/core/internal/cache_store/lru.py +5 -1
- dara/core/internal/cache_store/ttl.py +4 -1
- dara/core/internal/cgroup.py +1 -1
- dara/core/internal/dependency_resolution.py +60 -66
- dara/core/internal/devtools.py +12 -5
- dara/core/internal/download.py +13 -4
- dara/core/internal/encoder_registry.py +7 -7
- dara/core/internal/execute_action.py +13 -13
- dara/core/internal/hashing.py +1 -3
- dara/core/internal/import_discovery.py +3 -4
- dara/core/internal/multi_resource_lock.py +70 -0
- dara/core/internal/normalization.py +9 -18
- dara/core/internal/pandas_utils.py +107 -5
- dara/core/internal/pool/definitions.py +1 -1
- dara/core/internal/pool/task_pool.py +25 -16
- dara/core/internal/pool/utils.py +21 -18
- dara/core/internal/pool/worker.py +3 -2
- dara/core/internal/port_utils.py +1 -1
- dara/core/internal/registries.py +12 -6
- dara/core/internal/registry.py +4 -2
- dara/core/internal/registry_lookup.py +11 -5
- dara/core/internal/routing.py +109 -145
- dara/core/internal/scheduler.py +13 -8
- dara/core/internal/settings.py +2 -2
- dara/core/internal/store.py +2 -29
- dara/core/internal/tasks.py +379 -195
- dara/core/internal/utils.py +36 -13
- dara/core/internal/websocket.py +21 -20
- dara/core/js_tooling/js_utils.py +28 -26
- dara/core/js_tooling/templates/vite.config.template.ts +12 -3
- dara/core/logging.py +13 -12
- dara/core/main.py +14 -11
- dara/core/metrics/cache.py +1 -1
- dara/core/metrics/utils.py +3 -3
- dara/core/persistence.py +27 -5
- dara/core/umd/dara.core.umd.js +68291 -64718
- dara/core/visual/components/__init__.py +2 -2
- dara/core/visual/components/fallback.py +30 -4
- dara/core/visual/components/for_cmp.py +4 -1
- dara/core/visual/css/__init__.py +30 -31
- dara/core/visual/dynamic_component.py +31 -28
- dara/core/visual/progress_updater.py +4 -3
- {dara_core-1.20.1a1.dist-info → dara_core-1.20.1a3.dist-info}/METADATA +12 -11
- dara_core-1.20.1a3.dist-info/RECORD +119 -0
- dara_core-1.20.1a1.dist-info/RECORD +0 -114
- {dara_core-1.20.1a1.dist-info → dara_core-1.20.1a3.dist-info}/LICENSE +0 -0
- {dara_core-1.20.1a1.dist-info → dara_core-1.20.1a3.dist-info}/WHEEL +0 -0
- {dara_core-1.20.1a1.dist-info → dara_core-1.20.1a3.dist-info}/entry_points.txt +0 -0
dara/core/internal/tasks.py
CHANGED
|
@@ -15,26 +15,30 @@ See the License for the specific language governing permissions and
|
|
|
15
15
|
limitations under the License.
|
|
16
16
|
"""
|
|
17
17
|
|
|
18
|
+
import contextlib
|
|
18
19
|
import inspect
|
|
19
20
|
import math
|
|
20
|
-
from
|
|
21
|
+
from collections.abc import Awaitable
|
|
22
|
+
from typing import Any, Callable, Dict, List, Optional, Set, Union, overload
|
|
21
23
|
|
|
22
24
|
from anyio import (
|
|
25
|
+
BrokenResourceError,
|
|
23
26
|
CancelScope,
|
|
24
27
|
ClosedResourceError,
|
|
25
28
|
create_memory_object_stream,
|
|
26
29
|
create_task_group,
|
|
30
|
+
get_cancelled_exc_class,
|
|
27
31
|
move_on_after,
|
|
28
32
|
)
|
|
29
33
|
from anyio.abc import TaskGroup
|
|
30
34
|
from anyio.streams.memory import MemoryObjectSendStream
|
|
31
|
-
from exceptiongroup import ExceptionGroup
|
|
35
|
+
from exceptiongroup import ExceptionGroup, catch
|
|
32
36
|
from pydantic import ConfigDict
|
|
33
37
|
|
|
34
38
|
from dara.core.base_definitions import (
|
|
35
39
|
BaseTask,
|
|
36
|
-
Cache,
|
|
37
40
|
CachedRegistryEntry,
|
|
41
|
+
LruCachePolicy,
|
|
38
42
|
PendingTask,
|
|
39
43
|
TaskError,
|
|
40
44
|
TaskMessage,
|
|
@@ -45,7 +49,7 @@ from dara.core.internal.cache_store import CacheStore
|
|
|
45
49
|
from dara.core.internal.devtools import get_error_for_channel
|
|
46
50
|
from dara.core.internal.pandas_utils import remove_index
|
|
47
51
|
from dara.core.internal.pool import TaskPool
|
|
48
|
-
from dara.core.internal.utils import
|
|
52
|
+
from dara.core.internal.utils import exception_group_contains, run_user_handler
|
|
49
53
|
from dara.core.internal.websocket import WebsocketManager
|
|
50
54
|
from dara.core.logging import dev_logger, eng_logger
|
|
51
55
|
from dara.core.metrics import RUNTIME_METRICS_TRACKER
|
|
@@ -137,14 +141,12 @@ class Task(BaseTask):
|
|
|
137
141
|
|
|
138
142
|
async def on_progress(progress: float, msg: str):
|
|
139
143
|
if send_stream is not None:
|
|
140
|
-
|
|
144
|
+
with contextlib.suppress(ClosedResourceError):
|
|
141
145
|
await send_stream.send(TaskProgressUpdate(task_id=self.task_id, progress=progress, message=msg))
|
|
142
|
-
except ClosedResourceError:
|
|
143
|
-
pass
|
|
144
146
|
|
|
145
147
|
async def on_result(result: Any):
|
|
146
148
|
if send_stream is not None:
|
|
147
|
-
|
|
149
|
+
with contextlib.suppress(ClosedResourceError):
|
|
148
150
|
await send_stream.send(
|
|
149
151
|
TaskResult(
|
|
150
152
|
task_id=self.task_id,
|
|
@@ -153,19 +155,15 @@ class Task(BaseTask):
|
|
|
153
155
|
reg_entry=self.reg_entry,
|
|
154
156
|
)
|
|
155
157
|
)
|
|
156
|
-
except ClosedResourceError:
|
|
157
|
-
pass
|
|
158
158
|
|
|
159
159
|
async def on_error(exc: BaseException):
|
|
160
160
|
if send_stream is not None:
|
|
161
|
-
|
|
161
|
+
with contextlib.suppress(ClosedResourceError):
|
|
162
162
|
await send_stream.send(
|
|
163
163
|
TaskError(
|
|
164
164
|
task_id=self.task_id, error=exc, cache_key=self.cache_key, reg_entry=self.reg_entry
|
|
165
165
|
)
|
|
166
166
|
)
|
|
167
|
-
except ClosedResourceError:
|
|
168
|
-
pass
|
|
169
167
|
|
|
170
168
|
with pool.on_progress(self.task_id, on_progress):
|
|
171
169
|
pool_task_def = pool.submit(self.task_id, self._func_name, args=tuple(self._args), kwargs=self._kwargs)
|
|
@@ -218,7 +216,7 @@ class MetaTask(BaseTask):
|
|
|
218
216
|
:param notify_channels: If this task is run in a TaskManager instance these channels will also be notified on
|
|
219
217
|
completion
|
|
220
218
|
:param process_as_task: Whether to run the process_result function as a task or not, defaults to False
|
|
221
|
-
:param cache_key: Optional cache key if there is a
|
|
219
|
+
:param cache_key: Optional cache key if there is a registry entry to store results for the task in
|
|
222
220
|
:param task_id: Optional task_id to set for the task - otherwise the task generates its id automatically
|
|
223
221
|
"""
|
|
224
222
|
self.args = args if args is not None else []
|
|
@@ -239,96 +237,122 @@ class MetaTask(BaseTask):
|
|
|
239
237
|
|
|
240
238
|
:param send_stream: The stream to send messages to the task manager on
|
|
241
239
|
"""
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
# Collect up the tasks that need to be run and kick them off without awaiting them.
|
|
245
|
-
tasks.extend(x for x in self.args if isinstance(x, BaseTask))
|
|
246
|
-
tasks.extend(x for x in self.kwargs.values() if isinstance(x, BaseTask))
|
|
247
|
-
|
|
248
|
-
eng_logger.info(f'MetaTask {self.task_id} running sub-tasks', {'task_ids': [x.task_id for x in tasks]})
|
|
240
|
+
self.cancel_scope = CancelScope()
|
|
249
241
|
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
async
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
242
|
+
try:
|
|
243
|
+
with self.cancel_scope:
|
|
244
|
+
tasks: List[BaseTask] = []
|
|
245
|
+
|
|
246
|
+
# Collect up the tasks that need to be run and kick them off without awaiting them.
|
|
247
|
+
tasks.extend(x for x in self.args if isinstance(x, BaseTask))
|
|
248
|
+
tasks.extend(x for x in self.kwargs.values() if isinstance(x, BaseTask))
|
|
249
|
+
|
|
250
|
+
eng_logger.info(f'MetaTask {self.task_id} running sub-tasks', {'task_ids': [x.task_id for x in tasks]})
|
|
251
|
+
|
|
252
|
+
# Wait for all tasks to complete
|
|
253
|
+
results: Dict[str, Any] = {}
|
|
254
|
+
|
|
255
|
+
async def _run_and_capture_result(task: BaseTask):
|
|
256
|
+
"""
|
|
257
|
+
Run a task and capture the result
|
|
258
|
+
"""
|
|
259
|
+
nonlocal results
|
|
260
|
+
result = await task.run(send_stream)
|
|
261
|
+
results[task.task_id] = result
|
|
262
|
+
|
|
263
|
+
def handle_exception(err: ExceptionGroup):
|
|
264
|
+
with contextlib.suppress(ClosedResourceError, BrokenResourceError):
|
|
265
|
+
if send_stream is not None:
|
|
266
|
+
send_stream.send_nowait(
|
|
267
|
+
TaskError(
|
|
268
|
+
task_id=self.task_id, error=err, cache_key=self.cache_key, reg_entry=self.reg_entry
|
|
269
|
+
)
|
|
270
|
+
)
|
|
271
|
+
raise err
|
|
272
|
+
|
|
273
|
+
if len(tasks) > 0:
|
|
274
|
+
with catch(
|
|
275
|
+
{
|
|
276
|
+
BaseException: handle_exception, # type: ignore
|
|
277
|
+
get_cancelled_exc_class(): handle_exception,
|
|
278
|
+
}
|
|
279
|
+
):
|
|
280
|
+
async with create_task_group() as tg:
|
|
281
|
+
for task in tasks:
|
|
282
|
+
tg.start_soon(_run_and_capture_result, task)
|
|
283
|
+
|
|
284
|
+
# In testing somehow sometimes cancellations aren't bubbled up, this ensures that's the case
|
|
285
|
+
if tg.cancel_scope.cancel_called:
|
|
286
|
+
raise get_cancelled_exc_class()('MetaTask caught cancellation')
|
|
287
|
+
|
|
288
|
+
eng_logger.debug(f'MetaTask {self.task_id}', 'completed sub-tasks', results)
|
|
289
|
+
|
|
290
|
+
# Order the results in the same order as the tasks list
|
|
291
|
+
result_values = [results[task.task_id] for task in tasks]
|
|
292
|
+
|
|
293
|
+
args = []
|
|
294
|
+
kwargs = {}
|
|
295
|
+
|
|
296
|
+
# Rebuild the args and kwargs with the results of the underlying tasks
|
|
297
|
+
# Here the task results could be DataFrames so make sure we clean the internal __index__ col from them
|
|
298
|
+
# before passing into the task function
|
|
299
|
+
for arg in self.args:
|
|
300
|
+
if isinstance(arg, BaseTask):
|
|
301
|
+
args.append(remove_index(result_values.pop(0)))
|
|
302
|
+
else:
|
|
303
|
+
args.append(arg)
|
|
304
|
+
|
|
305
|
+
for k, val in self.kwargs.items():
|
|
306
|
+
if isinstance(val, BaseTask):
|
|
307
|
+
kwargs[k] = remove_index(result_values.pop(0))
|
|
308
|
+
else:
|
|
309
|
+
kwargs[k] = val
|
|
310
|
+
|
|
311
|
+
eng_logger.debug(f'MetaTask {self.task_id}', 'processing result', {'args': args, 'kwargs': kwargs})
|
|
312
|
+
|
|
313
|
+
# Run the process result function with the completed set of args and kwargs
|
|
314
|
+
if self.process_as_task:
|
|
315
|
+
eng_logger.debug(f'MetaTask {self.task_id}', 'processing result as Task')
|
|
316
|
+
|
|
317
|
+
task = Task(
|
|
318
|
+
self.process_result,
|
|
319
|
+
args,
|
|
320
|
+
kwargs,
|
|
321
|
+
# Pass through cache_key so the processing task correctly updates the cache store entry
|
|
322
|
+
reg_entry=self.reg_entry,
|
|
323
|
+
cache_key=self.cache_key,
|
|
324
|
+
task_id=self.task_id,
|
|
271
325
|
)
|
|
272
|
-
|
|
273
|
-
finally:
|
|
274
|
-
self.cancel_scope = None
|
|
275
|
-
|
|
276
|
-
eng_logger.debug(f'MetaTask {self.task_id}', 'completed sub-tasks', results)
|
|
326
|
+
res = await task.run(send_stream)
|
|
277
327
|
|
|
278
|
-
|
|
279
|
-
result_values = [results[task.task_id] for task in tasks]
|
|
328
|
+
eng_logger.info(f'MetaTask {self.task_id} returning result', {'result': res})
|
|
280
329
|
|
|
281
|
-
|
|
282
|
-
kwargs = {}
|
|
330
|
+
return res
|
|
283
331
|
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
# before passing into the task function
|
|
287
|
-
for arg in self.args:
|
|
288
|
-
if isinstance(arg, BaseTask):
|
|
289
|
-
args.append(remove_index(result_values.pop(0)))
|
|
290
|
-
else:
|
|
291
|
-
args.append(arg)
|
|
292
|
-
|
|
293
|
-
for k, val in self.kwargs.items():
|
|
294
|
-
if isinstance(val, BaseTask):
|
|
295
|
-
kwargs[k] = remove_index(result_values.pop(0))
|
|
296
|
-
else:
|
|
297
|
-
kwargs[k] = val
|
|
298
|
-
|
|
299
|
-
eng_logger.debug(f'MetaTask {self.task_id}', 'processing result', {'args': args, 'kwargs': kwargs})
|
|
300
|
-
|
|
301
|
-
# Run the process result function with the completed set of args and kwargs
|
|
302
|
-
if self.process_as_task:
|
|
303
|
-
eng_logger.debug(f'MetaTask {self.task_id}', 'processing result as Task')
|
|
304
|
-
# Pass through cache_key so the processing task correctly updates the cache store entry
|
|
305
|
-
task = Task(self.process_result, args, kwargs, cache_key=self.cache_key)
|
|
306
|
-
res = await task.run(send_stream)
|
|
332
|
+
try:
|
|
333
|
+
res = await run_user_handler(self.process_result, args, kwargs)
|
|
307
334
|
|
|
308
|
-
|
|
335
|
+
# Send MetaTask result - it could be that there is a nested structure
|
|
336
|
+
# of MetaTasks so we need to make sure intermediate results are also sent
|
|
337
|
+
if send_stream is not None:
|
|
338
|
+
await send_stream.send(
|
|
339
|
+
TaskResult(
|
|
340
|
+
task_id=self.task_id, result=res, cache_key=self.cache_key, reg_entry=self.reg_entry
|
|
341
|
+
)
|
|
342
|
+
)
|
|
343
|
+
except BaseException as e:
|
|
344
|
+
# Recover from error - update the pending value to prevent subsequent requests getting stuck
|
|
345
|
+
if send_stream is not None:
|
|
346
|
+
await send_stream.send(
|
|
347
|
+
TaskError(task_id=self.task_id, error=e, cache_key=self.cache_key, reg_entry=self.reg_entry)
|
|
348
|
+
)
|
|
349
|
+
raise
|
|
309
350
|
|
|
310
|
-
|
|
351
|
+
eng_logger.info(f'MetaTask {self.task_id} returning result', {'result': res})
|
|
311
352
|
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
# Send MetaTask result - it could be that there is a nested structure
|
|
316
|
-
# of MetaTasks so we need to make sure intermediate results are also sent
|
|
317
|
-
if send_stream is not None:
|
|
318
|
-
await send_stream.send(
|
|
319
|
-
TaskResult(task_id=self.task_id, result=res, cache_key=self.cache_key, reg_entry=self.reg_entry)
|
|
320
|
-
)
|
|
321
|
-
except BaseException as e:
|
|
322
|
-
# Recover from error - update the pending value to prevent subsequent requests getting stuck
|
|
323
|
-
if send_stream is not None:
|
|
324
|
-
await send_stream.send(
|
|
325
|
-
TaskError(task_id=self.task_id, error=e, cache_key=self.cache_key, reg_entry=self.reg_entry)
|
|
326
|
-
)
|
|
327
|
-
raise
|
|
328
|
-
|
|
329
|
-
eng_logger.info(f'MetaTask {self.task_id} returning result', {'result': res})
|
|
330
|
-
|
|
331
|
-
return res
|
|
353
|
+
return res
|
|
354
|
+
finally:
|
|
355
|
+
self.cancel_scope = None
|
|
332
356
|
|
|
333
357
|
async def cancel(self):
|
|
334
358
|
"""
|
|
@@ -347,8 +371,10 @@ class TaskManager:
|
|
|
347
371
|
TaskManager is responsible for running tasks and managing their pending state. It is also responsible for
|
|
348
372
|
communicating the state of tasks to the client via the WebsocketManager.
|
|
349
373
|
|
|
350
|
-
|
|
351
|
-
|
|
374
|
+
Every task created gets registered with the TaskManager and is tracked by the TaskManager
|
|
375
|
+
as a PendingTask.
|
|
376
|
+
|
|
377
|
+
This allows the task to be retrieved by the cache_key from the store.
|
|
352
378
|
|
|
353
379
|
When a task is completed, it is removed from the tasks dict and the store entry is updated with the result.
|
|
354
380
|
|
|
@@ -361,13 +387,19 @@ class TaskManager:
|
|
|
361
387
|
self.ws_manager = ws_manager
|
|
362
388
|
self.store = store
|
|
363
389
|
|
|
390
|
+
def register_task(self, task: BaseTask) -> PendingTask:
|
|
391
|
+
"""
|
|
392
|
+
Register a task. This will ensure the task it tracked and notifications are routed correctly.
|
|
393
|
+
"""
|
|
394
|
+
pending_task = PendingTask(task.task_id, task)
|
|
395
|
+
self.tasks[task.task_id] = pending_task
|
|
396
|
+
return pending_task
|
|
397
|
+
|
|
364
398
|
@overload
|
|
365
|
-
async def run_task(self, task: PendingTask, ws_channel: Optional[str] = None) -> Any:
|
|
366
|
-
...
|
|
399
|
+
async def run_task(self, task: PendingTask, ws_channel: Optional[str] = None) -> Any: ...
|
|
367
400
|
|
|
368
401
|
@overload
|
|
369
|
-
async def run_task(self, task: BaseTask, ws_channel: Optional[str] = None) -> PendingTask:
|
|
370
|
-
...
|
|
402
|
+
async def run_task(self, task: BaseTask, ws_channel: Optional[str] = None) -> PendingTask: ...
|
|
371
403
|
|
|
372
404
|
async def run_task(self, task: BaseTask, ws_channel: Optional[str] = None):
|
|
373
405
|
"""
|
|
@@ -380,7 +412,9 @@ class TaskManager:
|
|
|
380
412
|
# append the websocket channel to the task
|
|
381
413
|
if isinstance(task, PendingTask):
|
|
382
414
|
if task.task_id in self.tasks:
|
|
383
|
-
|
|
415
|
+
# Increment subscriber count for this component request
|
|
416
|
+
self.tasks[task.task_id].add_subscriber()
|
|
417
|
+
if ws_channel is not None:
|
|
384
418
|
self.tasks[task.task_id].notify_channels.append(ws_channel)
|
|
385
419
|
return self.tasks[task.task_id]
|
|
386
420
|
|
|
@@ -394,21 +428,65 @@ class TaskManager:
|
|
|
394
428
|
else self.get_result(task.task_id)
|
|
395
429
|
)
|
|
396
430
|
|
|
397
|
-
#
|
|
398
|
-
pending_task =
|
|
399
|
-
if
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
self.tasks[task.task_id] = pending_task
|
|
431
|
+
# Otherwise, we should already have a pending task for this task
|
|
432
|
+
pending_task = self.tasks[task.task_id]
|
|
433
|
+
if ws_channel is not None:
|
|
434
|
+
pending_task.notify_channels.append(ws_channel)
|
|
403
435
|
|
|
404
436
|
# Run the task in the background
|
|
405
|
-
self.task_group.start_soon(self._run_task_and_notify, task
|
|
437
|
+
self.task_group.start_soon(self._run_task_and_notify, task)
|
|
406
438
|
|
|
407
439
|
return pending_task
|
|
408
440
|
|
|
441
|
+
async def _cancel_tasks(self, task_ids: List[str], notify: bool = True):
|
|
442
|
+
"""
|
|
443
|
+
Cancel a list of tasks
|
|
444
|
+
|
|
445
|
+
:param task_ids: The list of task IDs to cancel
|
|
446
|
+
:param notify: Whether to send cancellation notifications
|
|
447
|
+
"""
|
|
448
|
+
with CancelScope(shield=True):
|
|
449
|
+
# Cancel all tasks in the hierarchy
|
|
450
|
+
for task_id_to_cancel in task_ids:
|
|
451
|
+
if task_id_to_cancel in self.tasks:
|
|
452
|
+
pending_task = self.tasks[task_id_to_cancel]
|
|
453
|
+
|
|
454
|
+
# Notify channels that this specific task was cancelled
|
|
455
|
+
if notify:
|
|
456
|
+
await self._send_notification_for_pending_task(
|
|
457
|
+
pending_task=pending_task,
|
|
458
|
+
messages=[{'status': 'CANCELED', 'task_id': task_id_to_cancel}],
|
|
459
|
+
)
|
|
460
|
+
|
|
461
|
+
if not pending_task.event.is_set():
|
|
462
|
+
# Cancel the actual task
|
|
463
|
+
await pending_task.cancel()
|
|
464
|
+
|
|
465
|
+
# Remove from cache if it has cache settings
|
|
466
|
+
if pending_task.task_def.cache_key is not None and pending_task.task_def.reg_entry is not None:
|
|
467
|
+
await self.store.delete(
|
|
468
|
+
pending_task.task_def.reg_entry, key=pending_task.task_def.cache_key
|
|
469
|
+
)
|
|
470
|
+
|
|
471
|
+
# Remove from running tasks
|
|
472
|
+
self.tasks.pop(task_id_to_cancel, None)
|
|
473
|
+
|
|
474
|
+
dev_logger.info('Task cancelled', {'task_id': task_id_to_cancel})
|
|
475
|
+
|
|
476
|
+
async def _cancel_task_hierarchy(self, task: BaseTask, notify: bool = True):
|
|
477
|
+
"""
|
|
478
|
+
Recursively cancel all tasks in a task hierarchy
|
|
479
|
+
|
|
480
|
+
:param task: The root task to cancel (and its children)
|
|
481
|
+
:param notify: Whether to send cancellation notifications
|
|
482
|
+
"""
|
|
483
|
+
all_task_ids = self._collect_all_task_ids_in_hierarchy(task)
|
|
484
|
+
await self._cancel_tasks(list(all_task_ids), notify)
|
|
485
|
+
|
|
409
486
|
async def cancel_task(self, task_id: str, notify: bool = True):
|
|
410
487
|
"""
|
|
411
|
-
Cancel a running task by its id
|
|
488
|
+
Cancel a running task by its id. If the task has child tasks (MetaTask),
|
|
489
|
+
all child tasks will also be cancelled.
|
|
412
490
|
|
|
413
491
|
:param task_id: the id of the task
|
|
414
492
|
:param notify: whether to notify, true by default
|
|
@@ -423,22 +501,9 @@ class TaskManager:
|
|
|
423
501
|
task.remove_subscriber()
|
|
424
502
|
return
|
|
425
503
|
|
|
426
|
-
#
|
|
427
|
-
|
|
428
|
-
for channel in [*task.notify_channels, *task.task_def.notify_channels]:
|
|
429
|
-
await self.ws_manager.send_message(channel, {'status': 'CANCELED', 'task_id': task_id})
|
|
430
|
-
|
|
431
|
-
# We're only now cancelling the task to make sure the clients are notified about cancelling
|
|
432
|
-
# and receive the correct status rather than an error
|
|
433
|
-
await task.cancel()
|
|
434
|
-
|
|
435
|
-
# Then remove the pending task from cache so next requests would recalculate rather than receive
|
|
436
|
-
# a broken pending task
|
|
437
|
-
if task.task_def.cache_key is not None and task.task_def.reg_entry is not None:
|
|
438
|
-
await self.store.set(task.task_def.reg_entry, key=task.task_def.cache_key, value=None)
|
|
504
|
+
# Cancel the entire task hierarchy (including child tasks)
|
|
505
|
+
await self._cancel_task_hierarchy(task.task_def, notify)
|
|
439
506
|
|
|
440
|
-
# Remove from running tasks
|
|
441
|
-
self.tasks.pop(task_id, None)
|
|
442
507
|
else:
|
|
443
508
|
raise TaskManagerError('Could not find a task with the passed id to cancel.')
|
|
444
509
|
|
|
@@ -459,12 +524,9 @@ class TaskManager:
|
|
|
459
524
|
|
|
460
525
|
:param task_id: the id of the task to fetch
|
|
461
526
|
"""
|
|
462
|
-
result
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
await self.store.delete(TaskResultEntry, key=task_id)
|
|
466
|
-
|
|
467
|
-
return result
|
|
527
|
+
# the result is not deleted, the results are kept in an LRU cache
|
|
528
|
+
# which will clean up older entries
|
|
529
|
+
return await self.store.get(TaskResultEntry, key=task_id)
|
|
468
530
|
|
|
469
531
|
async def set_result(self, task_id: str, value: Any):
|
|
470
532
|
"""
|
|
@@ -472,7 +534,82 @@ class TaskManager:
|
|
|
472
534
|
"""
|
|
473
535
|
return await self.store.set(TaskResultEntry, key=task_id, value=value)
|
|
474
536
|
|
|
475
|
-
|
|
537
|
+
def _collect_all_task_ids_in_hierarchy(self, task: BaseTask) -> Set[str]:
|
|
538
|
+
"""
|
|
539
|
+
Recursively collect all task IDs in the task hierarchy
|
|
540
|
+
|
|
541
|
+
:param task: The root task to start collecting from
|
|
542
|
+
:return: Set of all task IDs in the hierarchy
|
|
543
|
+
"""
|
|
544
|
+
task_ids = {task.task_id}
|
|
545
|
+
|
|
546
|
+
if isinstance(task, MetaTask):
|
|
547
|
+
# Collect from args
|
|
548
|
+
for arg in task.args:
|
|
549
|
+
if isinstance(arg, BaseTask):
|
|
550
|
+
task_ids.update(self._collect_all_task_ids_in_hierarchy(arg))
|
|
551
|
+
|
|
552
|
+
# Collect from kwargs
|
|
553
|
+
for value in task.kwargs.values():
|
|
554
|
+
if isinstance(value, BaseTask):
|
|
555
|
+
task_ids.update(self._collect_all_task_ids_in_hierarchy(value))
|
|
556
|
+
|
|
557
|
+
return task_ids
|
|
558
|
+
|
|
559
|
+
async def _multicast_notification(self, task_id: str, messages: List[dict]):
|
|
560
|
+
"""
|
|
561
|
+
Send notifications to all task IDs that are related to a given task
|
|
562
|
+
|
|
563
|
+
:param task: the task the notifications are related to
|
|
564
|
+
:param messages: List of message dictionaries to send to all related tasks
|
|
565
|
+
"""
|
|
566
|
+
# prevent cancellation, we need the notifications to be sent
|
|
567
|
+
with CancelScope(shield=True):
|
|
568
|
+
# Find all PendingTasks that have the message_task_id in their hierarchy
|
|
569
|
+
tasks_to_notify = set()
|
|
570
|
+
|
|
571
|
+
for pending_task in self.tasks.values():
|
|
572
|
+
# Check if the message_task_id is in this PendingTask's hierarchy
|
|
573
|
+
task_ids_in_hierarchy = self._collect_all_task_ids_in_hierarchy(pending_task.task_def)
|
|
574
|
+
if task_id in task_ids_in_hierarchy:
|
|
575
|
+
tasks_to_notify.add(pending_task.task_id)
|
|
576
|
+
|
|
577
|
+
# Send notifications for all affected PendingTasks in parallel
|
|
578
|
+
if tasks_to_notify:
|
|
579
|
+
async with create_task_group() as task_tg:
|
|
580
|
+
for pending_task_id in tasks_to_notify:
|
|
581
|
+
if pending_task_id not in self.tasks:
|
|
582
|
+
continue
|
|
583
|
+
pending_task = self.tasks[pending_task_id]
|
|
584
|
+
task_tg.start_soon(self._send_notification_for_pending_task, pending_task, messages)
|
|
585
|
+
|
|
586
|
+
async def _send_notification_for_pending_task(self, pending_task: PendingTask, messages: List[dict]):
|
|
587
|
+
"""
|
|
588
|
+
Send notifications for a specific PendingTask
|
|
589
|
+
|
|
590
|
+
:param pending_task: The PendingTask to send notifications for
|
|
591
|
+
:param messages: The messages to send
|
|
592
|
+
"""
|
|
593
|
+
# Collect channels for this PendingTask
|
|
594
|
+
channels_to_notify = set(pending_task.notify_channels)
|
|
595
|
+
channels_to_notify.update(pending_task.task_def.notify_channels)
|
|
596
|
+
|
|
597
|
+
if not channels_to_notify:
|
|
598
|
+
return
|
|
599
|
+
|
|
600
|
+
# Send to all channels for this PendingTask in parallel
|
|
601
|
+
async def _send_to_channel(channel: str):
|
|
602
|
+
async with create_task_group() as channel_tg:
|
|
603
|
+
for message in messages:
|
|
604
|
+
# Create message with this PendingTask's task_id (if message has task_id)
|
|
605
|
+
message_for_task = {**message, 'task_id': pending_task.task_id} if 'task_id' in message else message
|
|
606
|
+
channel_tg.start_soon(self.ws_manager.send_message, channel, message_for_task)
|
|
607
|
+
|
|
608
|
+
async with create_task_group() as channel_tg:
|
|
609
|
+
for channel in channels_to_notify:
|
|
610
|
+
channel_tg.start_soon(_send_to_channel, channel)
|
|
611
|
+
|
|
612
|
+
async def _run_task_and_notify(self, task: BaseTask):
|
|
476
613
|
"""
|
|
477
614
|
Run the task to completion and notify the client of progress and completion
|
|
478
615
|
|
|
@@ -481,23 +618,13 @@ class TaskManager:
|
|
|
481
618
|
"""
|
|
482
619
|
cancel_scope = CancelScope()
|
|
483
620
|
|
|
484
|
-
self.tasks[task.task_id]
|
|
621
|
+
pending_task = self.tasks[task.task_id]
|
|
622
|
+
|
|
623
|
+
pending_task.cancel_scope = cancel_scope
|
|
485
624
|
|
|
486
625
|
with cancel_scope:
|
|
487
626
|
eng_logger.info(f'TaskManager running task {task.task_id}')
|
|
488
627
|
|
|
489
|
-
async def notify_channels(*messages: dict):
|
|
490
|
-
"""
|
|
491
|
-
Notify the channels of the task's progress
|
|
492
|
-
"""
|
|
493
|
-
channels_to_notify = [*task.notify_channels]
|
|
494
|
-
if ws_channel:
|
|
495
|
-
channels_to_notify.append(ws_channel)
|
|
496
|
-
|
|
497
|
-
for channel in channels_to_notify:
|
|
498
|
-
for message in messages:
|
|
499
|
-
await self.ws_manager.send_message(channel, message)
|
|
500
|
-
|
|
501
628
|
# Create a memory object stream to capture messages from the tasks
|
|
502
629
|
send_stream, receive_stream = create_memory_object_stream[TaskMessage](math.inf)
|
|
503
630
|
|
|
@@ -505,32 +632,41 @@ class TaskManager:
|
|
|
505
632
|
async with receive_stream:
|
|
506
633
|
async for message in receive_stream:
|
|
507
634
|
if isinstance(message, TaskProgressUpdate):
|
|
508
|
-
#
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
)
|
|
635
|
+
# Send progress notifications to related tasks
|
|
636
|
+
progress_message = {
|
|
637
|
+
'task_id': message.task_id, # Will be updated per task ID in multicast
|
|
638
|
+
'status': 'PROGRESS',
|
|
639
|
+
'progress': message.progress,
|
|
640
|
+
'message': message.message,
|
|
641
|
+
}
|
|
642
|
+
await self._multicast_notification(message.task_id, [progress_message])
|
|
517
643
|
if isinstance(task, Task) and task.on_progress:
|
|
518
644
|
await run_user_handler(task.on_progress, args=(message,))
|
|
519
645
|
elif isinstance(message, TaskResult):
|
|
520
|
-
#
|
|
646
|
+
# Handle dual coordination patterns:
|
|
647
|
+
# 1. Direct-coordinated tasks: resolve via active_tasks registry
|
|
521
648
|
if message.task_id in self.tasks:
|
|
522
|
-
self.tasks[
|
|
523
|
-
|
|
649
|
+
self.tasks[message.task_id].resolve(message.result)
|
|
650
|
+
|
|
651
|
+
# 2. Cache-coordinated tasks: resolve via cache store (CacheStore.set handles PendingTask resolution)
|
|
524
652
|
if (
|
|
525
653
|
message.cache_key is not None
|
|
526
654
|
and message.reg_entry is not None
|
|
527
655
|
and message.reg_entry.cache is not None
|
|
528
656
|
):
|
|
529
657
|
await self.store.set(message.reg_entry, key=message.cache_key, value=message.result)
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
|
|
658
|
+
|
|
659
|
+
# Set final result
|
|
660
|
+
await self.set_result(message.task_id, message.result)
|
|
661
|
+
|
|
662
|
+
# Notify all PendingTasks that depend on this specific task
|
|
663
|
+
await self._multicast_notification(
|
|
664
|
+
task_id=message.task_id,
|
|
665
|
+
messages=[{'result': message.result, 'status': 'COMPLETE', 'task_id': message.task_id}],
|
|
533
666
|
)
|
|
667
|
+
|
|
668
|
+
# Remove the task from the registered tasks - it finished running
|
|
669
|
+
self.tasks.pop(message.task_id, None)
|
|
534
670
|
elif isinstance(message, TaskError):
|
|
535
671
|
# Fail the pending task related to the error
|
|
536
672
|
if message.task_id in self.tasks:
|
|
@@ -543,46 +679,94 @@ class TaskManager:
|
|
|
543
679
|
and message.reg_entry is not None
|
|
544
680
|
and message.reg_entry.cache is not None
|
|
545
681
|
):
|
|
546
|
-
await self.store.
|
|
682
|
+
await self.store.delete(message.reg_entry, key=message.cache_key)
|
|
683
|
+
|
|
684
|
+
# Notify all PendingTasks that depend on this specific task
|
|
685
|
+
error = get_error_for_channel(message.error)
|
|
686
|
+
await self._multicast_notification(
|
|
687
|
+
task_id=message.task_id,
|
|
688
|
+
messages=[
|
|
689
|
+
{'status': 'ERROR', 'task_id': message.task_id, 'error': error['error']},
|
|
690
|
+
error,
|
|
691
|
+
],
|
|
692
|
+
)
|
|
547
693
|
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
# Handle incoming messages in parallel
|
|
551
|
-
tg.start_soon(handle_messages)
|
|
694
|
+
# Remove the task from the registered tasks - it finished running
|
|
695
|
+
self.tasks.pop(message.task_id, None)
|
|
552
696
|
|
|
553
|
-
|
|
554
|
-
async with send_stream:
|
|
555
|
-
result = task
|
|
556
|
-
while isinstance(result, BaseTask):
|
|
557
|
-
result = await task.run(send_stream)
|
|
697
|
+
task_error: Optional[ExceptionGroup] = None
|
|
558
698
|
|
|
559
|
-
|
|
560
|
-
|
|
699
|
+
# ExceptionGroup handler can't be async so we just mark the task as errored
|
|
700
|
+
# and run the async handler in the finally block
|
|
701
|
+
def handle_exception(err: ExceptionGroup):
|
|
702
|
+
nonlocal task_error
|
|
703
|
+
task_error = err
|
|
561
704
|
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
|
|
705
|
+
try:
|
|
706
|
+
with catch({BaseException: handle_exception}): # type: ignore
|
|
707
|
+
async with create_task_group() as tg:
|
|
708
|
+
# Handle incoming messages in parallel
|
|
709
|
+
tg.start_soon(handle_messages)
|
|
710
|
+
|
|
711
|
+
# Handle tasks that return other tasks
|
|
712
|
+
async with send_stream:
|
|
713
|
+
result = task
|
|
714
|
+
while isinstance(result, BaseTask):
|
|
715
|
+
result = await task.run(send_stream)
|
|
716
|
+
|
|
717
|
+
# Notify any channels that need to be notified about the whole task being completed
|
|
718
|
+
await send_stream.send(
|
|
719
|
+
TaskResult(
|
|
720
|
+
task_id=task.task_id,
|
|
721
|
+
result=result,
|
|
722
|
+
cache_key=task.cache_key,
|
|
723
|
+
reg_entry=task.reg_entry,
|
|
724
|
+
)
|
|
725
|
+
)
|
|
726
|
+
eng_logger.info(f'TaskManager finished task {task.task_id}', {'result': result})
|
|
727
|
+
finally:
|
|
728
|
+
with CancelScope(shield=True):
|
|
729
|
+
# pyright: ignore[reportUnreachable]
|
|
730
|
+
if task_error is not None:
|
|
731
|
+
err = task_error
|
|
732
|
+
# Mark pending task as failed
|
|
733
|
+
pending_task.fail(err)
|
|
734
|
+
|
|
735
|
+
await self.set_result(task.task_id, {'error': str(err)})
|
|
736
|
+
|
|
737
|
+
# If the task has a cache key, set cached value to None
|
|
738
|
+
# This makes it so that the next request will recalculate the value rather than keep failing
|
|
739
|
+
if (
|
|
740
|
+
task.cache_key is not None
|
|
741
|
+
and task.reg_entry is not None
|
|
742
|
+
and task.reg_entry.cache is not None
|
|
743
|
+
):
|
|
744
|
+
await self.store.delete(task.reg_entry, key=task.cache_key)
|
|
745
|
+
|
|
746
|
+
# If this is a cancellation, ensure all tasks in the hierarchy are cancelled
|
|
747
|
+
if exception_group_contains(err_type=get_cancelled_exc_class(), group=err):
|
|
748
|
+
dev_logger.info('Task cancelled', {'task_id': task.task_id})
|
|
749
|
+
|
|
750
|
+
# Cancel any remaining tasks in the hierarchy that might still be running
|
|
751
|
+
await self._cancel_task_hierarchy(task, notify=True)
|
|
752
|
+
else:
|
|
753
|
+
dev_logger.error('Task failed', err, {'task_id': task.task_id})
|
|
754
|
+
|
|
755
|
+
error = get_error_for_channel(err)
|
|
756
|
+
message = {'status': 'ERROR', 'task_id': task.task_id, 'error': error['error']}
|
|
757
|
+
# Notify about this task failing, and a server broadcast error
|
|
758
|
+
await self._send_notification_for_pending_task(
|
|
759
|
+
pending_task=pending_task,
|
|
760
|
+
messages=[message, error],
|
|
761
|
+
)
|
|
762
|
+
# notify related tasks
|
|
763
|
+
await self._multicast_notification(
|
|
565
764
|
task_id=task.task_id,
|
|
566
|
-
|
|
567
|
-
cache_key=task.cache_key,
|
|
568
|
-
reg_entry=task.reg_entry,
|
|
765
|
+
messages=[message],
|
|
569
766
|
)
|
|
570
|
-
)
|
|
571
|
-
eng_logger.info(f'TaskManager finished task {task.task_id}', {'result': result})
|
|
572
|
-
except (Exception, ExceptionGroup) as err:
|
|
573
|
-
err = resolve_exception_group(err)
|
|
574
|
-
|
|
575
|
-
# Mark pending task as failed
|
|
576
|
-
self.tasks[task.task_id].fail(err)
|
|
577
767
|
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
# Notify any channels that need to be notified
|
|
582
|
-
await notify_channels({'status': 'ERROR', 'task_id': task.task_id}, get_error_for_channel())
|
|
583
|
-
finally:
|
|
584
|
-
# Remove the task from the running tasks
|
|
585
|
-
self.tasks.pop(task.task_id, None)
|
|
768
|
+
# Remove the task from the running tasks
|
|
769
|
+
self.tasks.pop(task.task_id, None)
|
|
586
770
|
|
|
587
771
|
# Make sure streams are closed
|
|
588
772
|
with move_on_after(3, shield=True):
|
|
@@ -590,7 +774,7 @@ class TaskManager:
|
|
|
590
774
|
await receive_stream.aclose()
|
|
591
775
|
|
|
592
776
|
|
|
593
|
-
TaskResultEntry = CachedRegistryEntry(uid='task-results', cache=
|
|
777
|
+
TaskResultEntry = CachedRegistryEntry(uid='task-results', cache=LruCachePolicy(max_size=256))
|
|
594
778
|
"""
|
|
595
779
|
Global registry entry for task results.
|
|
596
780
|
This is global because task ids are unique and accessed one time only so it's effectively a one-time use random key.
|