dara-core 1.19.0-py3-none-any.whl → 1.20.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dara/core/__init__.py +1 -0
- dara/core/auth/basic.py +13 -7
- dara/core/auth/definitions.py +2 -2
- dara/core/auth/utils.py +1 -1
- dara/core/base_definitions.py +7 -42
- dara/core/data_utils.py +16 -17
- dara/core/definitions.py +8 -8
- dara/core/interactivity/__init__.py +6 -0
- dara/core/interactivity/actions.py +26 -22
- dara/core/interactivity/any_data_variable.py +7 -135
- dara/core/interactivity/any_variable.py +1 -1
- dara/core/interactivity/client_variable.py +71 -0
- dara/core/interactivity/data_variable.py +8 -266
- dara/core/interactivity/derived_data_variable.py +6 -290
- dara/core/interactivity/derived_variable.py +381 -201
- dara/core/interactivity/filtering.py +29 -2
- dara/core/interactivity/loop_variable.py +2 -2
- dara/core/interactivity/non_data_variable.py +5 -68
- dara/core/interactivity/plain_variable.py +87 -14
- dara/core/interactivity/server_variable.py +325 -0
- dara/core/interactivity/state_variable.py +69 -0
- dara/core/interactivity/switch_variable.py +15 -15
- dara/core/interactivity/tabular_variable.py +94 -0
- dara/core/interactivity/url_variable.py +10 -90
- dara/core/internal/cache_store/cache_store.py +5 -20
- dara/core/internal/dependency_resolution.py +27 -69
- dara/core/internal/devtools.py +10 -3
- dara/core/internal/execute_action.py +9 -3
- dara/core/internal/multi_resource_lock.py +70 -0
- dara/core/internal/normalization.py +0 -5
- dara/core/internal/pandas_utils.py +105 -3
- dara/core/internal/pool/definitions.py +1 -1
- dara/core/internal/pool/task_pool.py +9 -6
- dara/core/internal/pool/utils.py +19 -14
- dara/core/internal/registries.py +3 -2
- dara/core/internal/registry.py +1 -1
- dara/core/internal/registry_lookup.py +5 -3
- dara/core/internal/routing.py +52 -121
- dara/core/internal/store.py +2 -29
- dara/core/internal/tasks.py +372 -182
- dara/core/internal/utils.py +25 -3
- dara/core/internal/websocket.py +1 -1
- dara/core/js_tooling/js_utils.py +2 -0
- dara/core/logging.py +10 -6
- dara/core/persistence.py +26 -4
- dara/core/umd/dara.core.umd.js +1091 -1469
- dara/core/visual/dynamic_component.py +17 -13
- {dara_core-1.19.0.dist-info → dara_core-1.20.0.dist-info}/METADATA +11 -11
- {dara_core-1.19.0.dist-info → dara_core-1.20.0.dist-info}/RECORD +52 -47
- {dara_core-1.19.0.dist-info → dara_core-1.20.0.dist-info}/LICENSE +0 -0
- {dara_core-1.19.0.dist-info → dara_core-1.20.0.dist-info}/WHEEL +0 -0
- {dara_core-1.19.0.dist-info → dara_core-1.20.0.dist-info}/entry_points.txt +0 -0
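Wheels are plain zip archives, so a file-level diff like the one below can be reproduced locally with only the standard library. A minimal sketch, assuming the two published wheels have already been downloaded from the registry:

import difflib
import zipfile

# Assumed local copies of the two published wheels
OLD = 'dara_core-1.19.0-py3-none-any.whl'
NEW = 'dara_core-1.20.0-py3-none-any.whl'

def read_member(wheel_path: str, member: str) -> list[str]:
    # Wheels are zip archives, so any member file can be read directly
    with zipfile.ZipFile(wheel_path) as zf:
        return zf.read(member).decode('utf-8').splitlines(keepends=True)

old_src = read_member(OLD, 'dara/core/internal/tasks.py')
new_src = read_member(NEW, 'dara/core/internal/tasks.py')

# Print a unified diff in the same format used below
for line in difflib.unified_diff(old_src, new_src, fromfile='1.19.0', tofile='1.20.0'):
    print(line, end='')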
dara/core/internal/tasks.py
CHANGED
@@ -19,24 +19,26 @@ import contextlib
 import inspect
 import math
 from collections.abc import Awaitable
-from typing import Any, Callable, Dict, List, Optional, Union, overload
+from typing import Any, Callable, Dict, List, Optional, Set, Union, overload

 from anyio import (
+    BrokenResourceError,
     CancelScope,
     ClosedResourceError,
     create_memory_object_stream,
     create_task_group,
+    get_cancelled_exc_class,
     move_on_after,
 )
 from anyio.abc import TaskGroup
 from anyio.streams.memory import MemoryObjectSendStream
-from exceptiongroup import ExceptionGroup
+from exceptiongroup import ExceptionGroup, catch
 from pydantic import ConfigDict

 from dara.core.base_definitions import (
     BaseTask,
-    Cache,
     CachedRegistryEntry,
+    LruCachePolicy,
     PendingTask,
     TaskError,
     TaskMessage,
@@ -47,7 +49,7 @@ from dara.core.internal.cache_store import CacheStore
 from dara.core.internal.devtools import get_error_for_channel
 from dara.core.internal.pandas_utils import remove_index
 from dara.core.internal.pool import TaskPool
-from dara.core.internal.utils import run_user_handler
+from dara.core.internal.utils import exception_group_contains, run_user_handler
 from dara.core.internal.websocket import WebsocketManager
 from dara.core.logging import dev_logger, eng_logger
 from dara.core.metrics import RUNTIME_METRICS_TRACKER
@@ -214,7 +216,7 @@ class MetaTask(BaseTask):
         :param notify_channels: If this task is run in a TaskManager instance these channels will also be notified on
         completion
         :param process_as_task: Whether to run the process_result function as a task or not, defaults to False
-        :param cache_key: Optional cache key if there is a
+        :param cache_key: Optional cache key if there is a registry entry to store results for the task in
         :param task_id: Optional task_id to set for the task - otherwise the task generates its id automatically
         """
         self.args = args if args is not None else []
@@ -235,96 +237,122 @@ class MetaTask(BaseTask):

         :param send_stream: The stream to send messages to the task manager on
         """
-
-        tasks: List[BaseTask] = []
-        # Collect up the tasks that need to be run and kick them off without awaiting them.
-        tasks.extend(x for x in self.args if isinstance(x, BaseTask))
-        tasks.extend(x for x in self.kwargs.values() if isinstance(x, BaseTask))
-
-        eng_logger.info(f'MetaTask {self.task_id} running sub-tasks', {'task_ids': [x.task_id for x in tasks]})
+        self.cancel_scope = CancelScope()

-
-
-
-
-
-
-
-
-
-
-
-
-
-        async
-
-
-
-
-
-
-
+        try:
+            with self.cancel_scope:
+                tasks: List[BaseTask] = []
+
+                # Collect up the tasks that need to be run and kick them off without awaiting them.
+                tasks.extend(x for x in self.args if isinstance(x, BaseTask))
+                tasks.extend(x for x in self.kwargs.values() if isinstance(x, BaseTask))
+
+                eng_logger.info(f'MetaTask {self.task_id} running sub-tasks', {'task_ids': [x.task_id for x in tasks]})
+
+                # Wait for all tasks to complete
+                results: Dict[str, Any] = {}
+
+                async def _run_and_capture_result(task: BaseTask):
+                    """
+                    Run a task and capture the result
+                    """
+                    nonlocal results
+                    result = await task.run(send_stream)
+                    results[task.task_id] = result
+
+                def handle_exception(err: ExceptionGroup):
+                    with contextlib.suppress(ClosedResourceError, BrokenResourceError):
+                        if send_stream is not None:
+                            send_stream.send_nowait(
+                                TaskError(
+                                    task_id=self.task_id, error=err, cache_key=self.cache_key, reg_entry=self.reg_entry
+                                )
+                            )
+                    raise err
+
+                if len(tasks) > 0:
+                    with catch(
+                        {
+                            BaseException: handle_exception,  # type: ignore
+                            get_cancelled_exc_class(): handle_exception,
+                        }
+                    ):
+                        async with create_task_group() as tg:
+                            for task in tasks:
+                                tg.start_soon(_run_and_capture_result, task)
+
+                    # In testing somehow sometimes cancellations aren't bubbled up, this ensures that's the case
+                    if tg.cancel_scope.cancel_called:
+                        raise get_cancelled_exc_class()('MetaTask caught cancellation')
+
+                eng_logger.debug(f'MetaTask {self.task_id}', 'completed sub-tasks', results)
+
+                # Order the results in the same order as the tasks list
+                result_values = [results[task.task_id] for task in tasks]
+
+                args = []
+                kwargs = {}
+
+                # Rebuild the args and kwargs with the results of the underlying tasks
+                # Here the task results could be DataFrames so make sure we clean the internal __index__ col from them
+                # before passing into the task function
+                for arg in self.args:
+                    if isinstance(arg, BaseTask):
+                        args.append(remove_index(result_values.pop(0)))
+                    else:
+                        args.append(arg)
+
+                for k, val in self.kwargs.items():
+                    if isinstance(val, BaseTask):
+                        kwargs[k] = remove_index(result_values.pop(0))
+                    else:
+                        kwargs[k] = val
+
+                eng_logger.debug(f'MetaTask {self.task_id}', 'processing result', {'args': args, 'kwargs': kwargs})
+
+                # Run the process result function with the completed set of args and kwargs
+                if self.process_as_task:
+                    eng_logger.debug(f'MetaTask {self.task_id}', 'processing result as Task')
+
+                    task = Task(
+                        self.process_result,
+                        args,
+                        kwargs,
+                        # Pass through cache_key so the processing task correctly updates the cache store entry
+                        reg_entry=self.reg_entry,
+                        cache_key=self.cache_key,
+                        task_id=self.task_id,
                     )
-
-        finally:
-            self.cancel_scope = None
-
-        eng_logger.debug(f'MetaTask {self.task_id}', 'completed sub-tasks', results)
+                    res = await task.run(send_stream)

-
-        result_values = [results[task.task_id] for task in tasks]
+                    eng_logger.info(f'MetaTask {self.task_id} returning result', {'result': res})

-        args = []
-        kwargs = {}
+                    return res

-        # Rebuild the args and kwargs with the results of the underlying tasks
-        # Here the task results could be DataFrames so make sure we clean the internal __index__ col from them
-        # before passing into the task function
-        for arg in self.args:
-            if isinstance(arg, BaseTask):
-                args.append(remove_index(result_values.pop(0)))
-            else:
-                args.append(arg)
-
-        for k, val in self.kwargs.items():
-            if isinstance(val, BaseTask):
-                kwargs[k] = remove_index(result_values.pop(0))
-            else:
-                kwargs[k] = val
-
-        eng_logger.debug(f'MetaTask {self.task_id}', 'processing result', {'args': args, 'kwargs': kwargs})
-
-        # Run the process result function with the completed set of args and kwargs
-        if self.process_as_task:
-            eng_logger.debug(f'MetaTask {self.task_id}', 'processing result as Task')
-            # Pass through cache_key so the processing task correctly updates the cache store entry
-            task = Task(self.process_result, args, kwargs, cache_key=self.cache_key)
-            res = await task.run(send_stream)
+                try:
+                    res = await run_user_handler(self.process_result, args, kwargs)

-
+                    # Send MetaTask result - it could be that there is a nested structure
+                    # of MetaTasks so we need to make sure intermediate results are also sent
+                    if send_stream is not None:
+                        await send_stream.send(
+                            TaskResult(
+                                task_id=self.task_id, result=res, cache_key=self.cache_key, reg_entry=self.reg_entry
+                            )
+                        )
+                except BaseException as e:
+                    # Recover from error - update the pending value to prevent subsequent requests getting stuck
+                    if send_stream is not None:
+                        await send_stream.send(
+                            TaskError(task_id=self.task_id, error=e, cache_key=self.cache_key, reg_entry=self.reg_entry)
+                        )
+                    raise

-
+                eng_logger.info(f'MetaTask {self.task_id} returning result', {'result': res})

-
-
-
-        # Send MetaTask result - it could be that there is a nested structure
-        # of MetaTasks so we need to make sure intermediate results are also sent
-        if send_stream is not None:
-            await send_stream.send(
-                TaskResult(task_id=self.task_id, result=res, cache_key=self.cache_key, reg_entry=self.reg_entry)
-            )
-        except BaseException as e:
-            # Recover from error - update the pending value to prevent subsequent requests getting stuck
-            if send_stream is not None:
-                await send_stream.send(
-                    TaskError(task_id=self.task_id, error=e, cache_key=self.cache_key, reg_entry=self.reg_entry)
-                )
-            raise
-
-        eng_logger.info(f'MetaTask {self.task_id} returning result', {'result': res})
-
-        return res
+                return res
+        finally:
+            self.cancel_scope = None

     async def cancel(self):
         """
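The rewritten MetaTask.run leans on exceptiongroup.catch, newly imported above: catch takes a mapping of exception types to handlers, and the handler is invoked with the whole matching ExceptionGroup raised out of an anyio task group. A minimal standalone sketch of that pattern (illustrative code, not from the package):

import anyio
from exceptiongroup import ExceptionGroup, catch

def handle_group(excgroup: ExceptionGroup) -> None:
    # Handlers receive the entire matching ExceptionGroup, not a single exception
    for exc in excgroup.exceptions:
        print(f'sub-task failed: {exc!r}')

async def boom() -> None:
    raise ValueError('sub-task error')

async def main() -> None:
    # catch() must wrap the task group: the group's ExceptionGroup is routed
    # to the handler registered for the matching exception type
    with catch({ValueError: handle_group}):
        async with anyio.create_task_group() as tg:
            tg.start_soon(boom)
            tg.start_soon(boom)

anyio.run(main)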
@@ -343,8 +371,10 @@ class TaskManager:
     TaskManager is responsible for running tasks and managing their pending state. It is also responsible for
     communicating the state of tasks to the client via the WebsocketManager.

-
-
+    Every task created gets registered with the TaskManager and is tracked by the TaskManager
+    as a PendingTask.
+
+    This allows the task to be retrieved by the cache_key from the store.

     When a task is completed, it is removed from the tasks dict and the store entry is updated with the result.

@@ -357,6 +387,14 @@ class TaskManager:
         self.ws_manager = ws_manager
         self.store = store

+    def register_task(self, task: BaseTask) -> PendingTask:
+        """
+        Register a task. This will ensure the task is tracked and notifications are routed correctly.
+        """
+        pending_task = PendingTask(task.task_id, task)
+        self.tasks[task.task_id] = pending_task
+        return pending_task
+
     @overload
     async def run_task(self, task: PendingTask, ws_channel: Optional[str] = None) -> Any: ...

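Together with the run_task changes below, register_task splits tracking out of run_task: the PendingTask is created and stored up front, and run_task then only attaches subscribers and notification channels ("we should already have a pending task for this task"). A hedged sketch of the implied calling convention; the names here are assumptions for illustration, not taken from the diff:

# Illustrative only: `manager` is a TaskManager and `task` a BaseTask
async def schedule(manager, task):
    pending = manager.register_task(task)               # create and track the PendingTask first
    await manager.run_task(task, ws_channel='ws-123')   # then attach a channel and run in background
    return pending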
@@ -374,7 +412,9 @@ class TaskManager:
         # append the websocket channel to the task
         if isinstance(task, PendingTask):
             if task.task_id in self.tasks:
-                if ws_channel is not None:
+                # Increment subscriber count for this component request
+                self.tasks[task.task_id].add_subscriber()
+                if ws_channel is not None:
                     self.tasks[task.task_id].notify_channels.append(ws_channel)
             return self.tasks[task.task_id]

@@ -388,21 +428,65 @@ class TaskManager:
             else self.get_result(task.task_id)
         )

-        #
-        pending_task =
-        if
-
-
-        self.tasks[task.task_id] = pending_task
+        # Otherwise, we should already have a pending task for this task
+        pending_task = self.tasks[task.task_id]
+        if ws_channel is not None:
+            pending_task.notify_channels.append(ws_channel)

         # Run the task in the background
-        self.task_group.start_soon(self._run_task_and_notify, task
+        self.task_group.start_soon(self._run_task_and_notify, task)

         return pending_task

+    async def _cancel_tasks(self, task_ids: List[str], notify: bool = True):
+        """
+        Cancel a list of tasks
+
+        :param task_ids: The list of task IDs to cancel
+        :param notify: Whether to send cancellation notifications
+        """
+        with CancelScope(shield=True):
+            # Cancel all tasks in the hierarchy
+            for task_id_to_cancel in task_ids:
+                if task_id_to_cancel in self.tasks:
+                    pending_task = self.tasks[task_id_to_cancel]
+
+                    # Notify channels that this specific task was cancelled
+                    if notify:
+                        await self._send_notification_for_pending_task(
+                            pending_task=pending_task,
+                            messages=[{'status': 'CANCELED', 'task_id': task_id_to_cancel}],
+                        )
+
+                    if not pending_task.event.is_set():
+                        # Cancel the actual task
+                        await pending_task.cancel()
+
+                    # Remove from cache if it has cache settings
+                    if pending_task.task_def.cache_key is not None and pending_task.task_def.reg_entry is not None:
+                        await self.store.delete(
+                            pending_task.task_def.reg_entry, key=pending_task.task_def.cache_key
+                        )
+
+                    # Remove from running tasks
+                    self.tasks.pop(task_id_to_cancel, None)
+
+                    dev_logger.info('Task cancelled', {'task_id': task_id_to_cancel})
+
+    async def _cancel_task_hierarchy(self, task: BaseTask, notify: bool = True):
+        """
+        Recursively cancel all tasks in a task hierarchy
+
+        :param task: The root task to cancel (and its children)
+        :param notify: Whether to send cancellation notifications
+        """
+        all_task_ids = self._collect_all_task_ids_in_hierarchy(task)
+        await self._cancel_tasks(list(all_task_ids), notify)
+
     async def cancel_task(self, task_id: str, notify: bool = True):
         """
-        Cancel a running task by its id
+        Cancel a running task by its id. If the task has child tasks (MetaTask),
+        all child tasks will also be cancelled.

         :param task_id: the id of the task
         :param notify: whether to notify, true by default
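Both new cancellation helpers do their work inside CancelScope(shield=True), anyio's mechanism for keeping cleanup running even while the surrounding scope is being cancelled, so CANCELED notifications still go out. A minimal self-contained illustration of shielding (not dara code):

import anyio

async def worker() -> None:
    try:
        await anyio.sleep(10)
    finally:
        # Without shield=True this cleanup would itself be cancelled
        with anyio.CancelScope(shield=True):
            await anyio.sleep(0.1)  # e.g. send a CANCELED notification
            print('cleanup finished despite cancellation')

async def main() -> None:
    async with anyio.create_task_group() as tg:
        tg.start_soon(worker)
        await anyio.sleep(0.1)
        tg.cancel_scope.cancel()

anyio.run(main)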
@@ -417,22 +501,9 @@ class TaskManager:
                 task.remove_subscriber()
                 return

-            #
-
-            for channel in [*task.notify_channels, *task.task_def.notify_channels]:
-                await self.ws_manager.send_message(channel, {'status': 'CANCELED', 'task_id': task_id})
-
-            # We're only now cancelling the task to make sure the clients are notified about cancelling
-            # and receive the correct status rather than an error
-            await task.cancel()
-
-            # Then remove the pending task from cache so next requests would recalculate rather than receive
-            # a broken pending task
-            if task.task_def.cache_key is not None and task.task_def.reg_entry is not None:
-                await self.store.set(task.task_def.reg_entry, key=task.task_def.cache_key, value=None)
+            # Cancel the entire task hierarchy (including child tasks)
+            await self._cancel_task_hierarchy(task.task_def, notify)

-            # Remove from running tasks
-            self.tasks.pop(task_id, None)
         else:
             raise TaskManagerError('Could not find a task with the passed id to cancel.')

@@ -453,12 +524,9 @@ class TaskManager:

         :param task_id: the id of the task to fetch
         """
-        result
-
-
-        await self.store.delete(TaskResultEntry, key=task_id)
-
-        return result
+        # the result is not deleted, the results are kept in an LRU cache
+        # which will clean up older entries
+        return await self.store.get(TaskResultEntry, key=task_id)

     async def set_result(self, task_id: str, value: Any):
         """
@@ -466,7 +534,82 @@ class TaskManager:
         """
         return await self.store.set(TaskResultEntry, key=task_id, value=value)

-    async def _run_task_and_notify(self, task: BaseTask, ws_channel: Optional[str] = None):
+    def _collect_all_task_ids_in_hierarchy(self, task: BaseTask) -> Set[str]:
+        """
+        Recursively collect all task IDs in the task hierarchy
+
+        :param task: The root task to start collecting from
+        :return: Set of all task IDs in the hierarchy
+        """
+        task_ids = {task.task_id}
+
+        if isinstance(task, MetaTask):
+            # Collect from args
+            for arg in task.args:
+                if isinstance(arg, BaseTask):
+                    task_ids.update(self._collect_all_task_ids_in_hierarchy(arg))
+
+            # Collect from kwargs
+            for value in task.kwargs.values():
+                if isinstance(value, BaseTask):
+                    task_ids.update(self._collect_all_task_ids_in_hierarchy(value))
+
+        return task_ids
+
+    async def _multicast_notification(self, task_id: str, messages: List[dict]):
+        """
+        Send notifications to all task IDs that are related to a given task
+
+        :param task_id: the task the notifications are related to
+        :param messages: List of message dictionaries to send to all related tasks
+        """
+        # prevent cancellation, we need the notifications to be sent
+        with CancelScope(shield=True):
+            # Find all PendingTasks that have the message_task_id in their hierarchy
+            tasks_to_notify = set()
+
+            for pending_task in self.tasks.values():
+                # Check if the message_task_id is in this PendingTask's hierarchy
+                task_ids_in_hierarchy = self._collect_all_task_ids_in_hierarchy(pending_task.task_def)
+                if task_id in task_ids_in_hierarchy:
+                    tasks_to_notify.add(pending_task.task_id)
+
+            # Send notifications for all affected PendingTasks in parallel
+            if tasks_to_notify:
+                async with create_task_group() as task_tg:
+                    for pending_task_id in tasks_to_notify:
+                        if pending_task_id not in self.tasks:
+                            continue
+                        pending_task = self.tasks[pending_task_id]
+                        task_tg.start_soon(self._send_notification_for_pending_task, pending_task, messages)
+
+    async def _send_notification_for_pending_task(self, pending_task: PendingTask, messages: List[dict]):
+        """
+        Send notifications for a specific PendingTask
+
+        :param pending_task: The PendingTask to send notifications for
+        :param messages: The messages to send
+        """
+        # Collect channels for this PendingTask
+        channels_to_notify = set(pending_task.notify_channels)
+        channels_to_notify.update(pending_task.task_def.notify_channels)
+
+        if not channels_to_notify:
+            return
+
+        # Send to all channels for this PendingTask in parallel
+        async def _send_to_channel(channel: str):
+            async with create_task_group() as channel_tg:
+                for message in messages:
+                    # Create message with this PendingTask's task_id (if message has task_id)
+                    message_for_task = {**message, 'task_id': pending_task.task_id} if 'task_id' in message else message
+                    channel_tg.start_soon(self.ws_manager.send_message, channel, message_for_task)
+
+        async with create_task_group() as channel_tg:
+            for channel in channels_to_notify:
+                channel_tg.start_soon(_send_to_channel, channel)
+
+    async def _run_task_and_notify(self, task: BaseTask):
         """
         Run the task to completion and notify the client of progress and completion

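The hierarchy walk added above is a plain recursive set union over nested args/kwargs. The same idea in a self-contained model, with simplified stand-in classes rather than the dara task types:

from dataclasses import dataclass, field
from typing import Any, Dict, List, Set

@dataclass
class FakeTask:
    task_id: str
    args: List[Any] = field(default_factory=list)
    kwargs: Dict[str, Any] = field(default_factory=dict)

def collect_ids(task: FakeTask) -> Set[str]:
    # Same shape as _collect_all_task_ids_in_hierarchy: the root id plus
    # every nested task found in args or kwargs, recursively
    ids = {task.task_id}
    for child in [*task.args, *task.kwargs.values()]:
        if isinstance(child, FakeTask):
            ids |= collect_ids(child)
    return ids

leaf_a, leaf_b = FakeTask('a'), FakeTask('b')
meta = FakeTask('m', args=[leaf_a], kwargs={'b': leaf_b})
assert collect_ids(meta) == {'m', 'a', 'b'}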
@@ -475,23 +618,13 @@ class TaskManager:
         """
         cancel_scope = CancelScope()

-        self.tasks[task.task_id].cancel_scope = cancel_scope
+        pending_task = self.tasks[task.task_id]
+
+        pending_task.cancel_scope = cancel_scope

         with cancel_scope:
             eng_logger.info(f'TaskManager running task {task.task_id}')

-            async def notify_channels(*messages: dict):
-                """
-                Notify the channels of the task's progress
-                """
-                channels_to_notify = [*task.notify_channels]
-                if ws_channel:
-                    channels_to_notify.append(ws_channel)
-
-                for channel in channels_to_notify:
-                    for message in messages:
-                        await self.ws_manager.send_message(channel, message)
-
             # Create a memory object stream to capture messages from the tasks
             send_stream, receive_stream = create_memory_object_stream[TaskMessage](math.inf)

@@ -499,32 +632,41 @@ class TaskManager:
                 async with receive_stream:
                     async for message in receive_stream:
                         if isinstance(message, TaskProgressUpdate):
-                            #
-
-
-
-
-
-
-
-                            )
+                            # Send progress notifications to related tasks
+                            progress_message = {
+                                'task_id': message.task_id,  # Will be updated per task ID in multicast
+                                'status': 'PROGRESS',
+                                'progress': message.progress,
+                                'message': message.message,
+                            }
+                            await self._multicast_notification(message.task_id, [progress_message])
                             if isinstance(task, Task) and task.on_progress:
                                 await run_user_handler(task.on_progress, args=(message,))
                         elif isinstance(message, TaskResult):
-                            #
+                            # Handle dual coordination patterns:
+                            # 1. Direct-coordinated tasks: resolve via active_tasks registry
                             if message.task_id in self.tasks:
-                                self.tasks[
-
+                                self.tasks[message.task_id].resolve(message.result)
+
+                            # 2. Cache-coordinated tasks: resolve via cache store (CacheStore.set handles PendingTask resolution)
                             if (
                                 message.cache_key is not None
                                 and message.reg_entry is not None
                                 and message.reg_entry.cache is not None
                             ):
                                 await self.store.set(message.reg_entry, key=message.cache_key, value=message.result)
-
-
-
+
+                            # Set final result
+                            await self.set_result(message.task_id, message.result)
+
+                            # Notify all PendingTasks that depend on this specific task
+                            await self._multicast_notification(
+                                task_id=message.task_id,
+                                messages=[{'result': message.result, 'status': 'COMPLETE', 'task_id': message.task_id}],
                             )
+
+                            # Remove the task from the registered tasks - it finished running
+                            self.tasks.pop(message.task_id, None)
                         elif isinstance(message, TaskError):
                             # Fail the pending task related to the error
                             if message.task_id in self.tasks:
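The message loop above is the consumer half of an anyio memory object stream: tasks push TaskProgressUpdate/TaskResult/TaskError messages into the send side and handle_messages drains the receive side until it closes. A stripped-down version of that producer/consumer shape, using plain dicts instead of the dara message types:

import math
import anyio

async def producer(send) -> None:
    async with send:  # closing the send side ends the consumer's loop
        await send.send({'status': 'PROGRESS', 'progress': 50})
        await send.send({'status': 'COMPLETE', 'result': 42})

async def consumer(receive) -> None:
    async with receive:
        async for message in receive:  # ends when the send side closes
            print(message)

async def main() -> None:
    send, receive = anyio.create_memory_object_stream(math.inf)
    async with anyio.create_task_group() as tg:
        tg.start_soon(consumer, receive)
        tg.start_soon(producer, send)

anyio.run(main)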
@@ -537,46 +679,94 @@ class TaskManager:
                                 and message.reg_entry is not None
                                 and message.reg_entry.cache is not None
                             ):
-                                await self.store.
+                                await self.store.delete(message.reg_entry, key=message.cache_key)
+
+                            # Notify all PendingTasks that depend on this specific task
+                            error = get_error_for_channel(message.error)
+                            await self._multicast_notification(
+                                task_id=message.task_id,
+                                messages=[
+                                    {'status': 'ERROR', 'task_id': message.task_id, 'error': error['error']},
+                                    error,
+                                ],
+                            )

-            try:
-                async with create_task_group() as tg:
-                    # Handle incoming messages in parallel
-                    tg.start_soon(handle_messages)
+                            # Remove the task from the registered tasks - it finished running
+                            self.tasks.pop(message.task_id, None)

-                    # Handle tasks that return other tasks
-                    async with send_stream:
-                        result = task
-                        while isinstance(result, BaseTask):
-                            result = await task.run(send_stream)
+            task_error: Optional[ExceptionGroup] = None

-
-
+            # ExceptionGroup handler can't be async so we just mark the task as errored
+            # and run the async handler in the finally block
+            def handle_exception(err: ExceptionGroup):
+                nonlocal task_error
+                task_error = err

-                        # Notify any channels that need to be notified about the whole task being completed
-                        await send_stream.send(
-                            TaskResult(
+            try:
+                with catch({BaseException: handle_exception}):  # type: ignore
+                    async with create_task_group() as tg:
+                        # Handle incoming messages in parallel
+                        tg.start_soon(handle_messages)
+
+                        # Handle tasks that return other tasks
+                        async with send_stream:
+                            result = task
+                            while isinstance(result, BaseTask):
+                                result = await task.run(send_stream)
+
+                            # Notify any channels that need to be notified about the whole task being completed
+                            await send_stream.send(
+                                TaskResult(
+                                    task_id=task.task_id,
+                                    result=result,
+                                    cache_key=task.cache_key,
+                                    reg_entry=task.reg_entry,
+                                )
+                            )
+                            eng_logger.info(f'TaskManager finished task {task.task_id}', {'result': result})
+            finally:
+                with CancelScope(shield=True):
+                    # pyright: ignore[reportUnreachable]
+                    if task_error is not None:
+                        err = task_error
+                        # Mark pending task as failed
+                        pending_task.fail(err)
+
+                        await self.set_result(task.task_id, {'error': str(err)})
+
+                        # If the task has a cache key, set cached value to None
+                        # This makes it so that the next request will recalculate the value rather than keep failing
+                        if (
+                            task.cache_key is not None
+                            and task.reg_entry is not None
+                            and task.reg_entry.cache is not None
+                        ):
+                            await self.store.delete(task.reg_entry, key=task.cache_key)
+
+                        # If this is a cancellation, ensure all tasks in the hierarchy are cancelled
+                        if exception_group_contains(err_type=get_cancelled_exc_class(), group=err):
+                            dev_logger.info('Task cancelled', {'task_id': task.task_id})
+
+                            # Cancel any remaining tasks in the hierarchy that might still be running
+                            await self._cancel_task_hierarchy(task, notify=True)
+                        else:
+                            dev_logger.error('Task failed', err, {'task_id': task.task_id})
+
+                            error = get_error_for_channel(err)
+                            message = {'status': 'ERROR', 'task_id': task.task_id, 'error': error['error']}
+                            # Notify about this task failing, and a server broadcast error
+                            await self._send_notification_for_pending_task(
+                                pending_task=pending_task,
+                                messages=[message, error],
+                            )
+                            # notify related tasks
+                            await self._multicast_notification(
                                 task_id=task.task_id,
-                                result=result,
-                                cache_key=task.cache_key,
-                                reg_entry=task.reg_entry,
+                                messages=[message],
                             )
-                        )
-                        eng_logger.info(f'TaskManager finished task {task.task_id}', {'result': result})
-            except (Exception, ExceptionGroup) as err:
-                err = resolve_exception_group(err)
-
-                # Mark pending task as failed
-                self.tasks[task.task_id].fail(err)

-
-
-
-                # Notify any channels that need to be notified
-                await notify_channels({'status': 'ERROR', 'task_id': task.task_id}, get_error_for_channel())
-            finally:
-                # Remove the task from the running tasks
-                self.tasks.pop(task.task_id, None)
+                    # Remove the task from the running tasks
+                    self.tasks.pop(task.task_id, None)

             # Make sure streams are closed
             with move_on_after(3, shield=True):
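exception_group_contains comes from dara/core/internal/utils.py (expanded in this release per the file list above); its implementation is not shown in this diff, but the call site implies a predicate that walks a possibly nested exception group looking for a given exception type. A plausible sketch of those assumed semantics:

from exceptiongroup import BaseExceptionGroup

def exception_group_contains(err_type: type, group: BaseException) -> bool:
    # Assumed semantics, inferred from the call site above: recurse through a
    # possibly nested exception group and report whether any leaf matches err_type
    if isinstance(group, BaseExceptionGroup):
        return any(exception_group_contains(err_type, exc) for exc in group.exceptions)
    return isinstance(group, err_type)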
@@ -584,7 +774,7 @@
         await receive_stream.aclose()


-TaskResultEntry = CachedRegistryEntry(uid='task-results', cache=
+TaskResultEntry = CachedRegistryEntry(uid='task-results', cache=LruCachePolicy(max_size=256))
 """
 Global registry entry for task results.
 This is global because task ids are unique and accessed one time only so it's effectively a one-time use random key.
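With TaskResultEntry now backed by LruCachePolicy(max_size=256), task results survive repeated reads (matching the new get_result above, which no longer deletes on fetch) and stale entries are evicted by recency. A toy model of that eviction behaviour, using OrderedDict rather than dara's cache store:

from collections import OrderedDict
from typing import Any, Optional

class LruStore:
    """Toy stand-in for an LRU cache policy with a maximum size."""

    def __init__(self, max_size: int = 256):
        self.max_size = max_size
        self._data: 'OrderedDict[str, Any]' = OrderedDict()

    def set(self, key: str, value: Any) -> None:
        self._data[key] = value
        self._data.move_to_end(key)        # newest entry is most recent
        if len(self._data) > self.max_size:
            self._data.popitem(last=False)  # evict the least recently used

    def get(self, key: str) -> Optional[Any]:
        if key not in self._data:
            return None
        self._data.move_to_end(key)        # reads refresh recency
        return self._data[key]

store = LruStore(max_size=2)
store.set('t1', 'r1'); store.set('t2', 'r2'); store.set('t3', 'r3')
assert store.get('t1') is None   # evicted by the size limit
assert store.get('t3') == 'r3'   # repeated reads keep working
assert store.get('t3') == 'r3'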
|