dara-core 1.20.0a1-py3-none-any.whl → 1.20.1a1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. dara/core/__init__.py +0 -3
  2. dara/core/actions.py +2 -1
  3. dara/core/auth/basic.py +16 -22
  4. dara/core/auth/definitions.py +2 -2
  5. dara/core/auth/routes.py +5 -5
  6. dara/core/auth/utils.py +5 -5
  7. dara/core/base_definitions.py +64 -22
  8. dara/core/cli.py +7 -8
  9. dara/core/configuration.py +2 -5
  10. dara/core/css.py +2 -1
  11. dara/core/data_utils.py +19 -18
  12. dara/core/defaults.py +7 -6
  13. dara/core/definitions.py +19 -50
  14. dara/core/http.py +3 -7
  15. dara/core/interactivity/__init__.py +0 -6
  16. dara/core/interactivity/actions.py +50 -52
  17. dara/core/interactivity/any_data_variable.py +134 -7
  18. dara/core/interactivity/any_variable.py +8 -5
  19. dara/core/interactivity/data_variable.py +266 -8
  20. dara/core/interactivity/derived_data_variable.py +290 -7
  21. dara/core/interactivity/derived_variable.py +174 -414
  22. dara/core/interactivity/filtering.py +27 -46
  23. dara/core/interactivity/loop_variable.py +2 -2
  24. dara/core/interactivity/non_data_variable.py +68 -5
  25. dara/core/interactivity/plain_variable.py +15 -89
  26. dara/core/interactivity/switch_variable.py +19 -19
  27. dara/core/interactivity/url_variable.py +90 -10
  28. dara/core/internal/cache_store/base_impl.py +1 -2
  29. dara/core/internal/cache_store/cache_store.py +25 -22
  30. dara/core/internal/cache_store/keep_all.py +1 -4
  31. dara/core/internal/cache_store/lru.py +1 -5
  32. dara/core/internal/cache_store/ttl.py +1 -4
  33. dara/core/internal/cgroup.py +1 -1
  34. dara/core/internal/dependency_resolution.py +66 -60
  35. dara/core/internal/devtools.py +5 -12
  36. dara/core/internal/download.py +4 -13
  37. dara/core/internal/encoder_registry.py +7 -7
  38. dara/core/internal/execute_action.py +13 -13
  39. dara/core/internal/hashing.py +3 -1
  40. dara/core/internal/import_discovery.py +4 -3
  41. dara/core/internal/normalization.py +18 -9
  42. dara/core/internal/pandas_utils.py +5 -107
  43. dara/core/internal/pool/definitions.py +1 -1
  44. dara/core/internal/pool/task_pool.py +16 -25
  45. dara/core/internal/pool/utils.py +18 -21
  46. dara/core/internal/pool/worker.py +2 -3
  47. dara/core/internal/port_utils.py +1 -1
  48. dara/core/internal/registries.py +6 -12
  49. dara/core/internal/registry.py +2 -4
  50. dara/core/internal/registry_lookup.py +5 -11
  51. dara/core/internal/routing.py +145 -109
  52. dara/core/internal/scheduler.py +8 -13
  53. dara/core/internal/settings.py +2 -2
  54. dara/core/internal/store.py +29 -2
  55. dara/core/internal/tasks.py +195 -379
  56. dara/core/internal/utils.py +13 -36
  57. dara/core/internal/websocket.py +20 -21
  58. dara/core/js_tooling/js_utils.py +26 -28
  59. dara/core/js_tooling/templates/vite.config.template.ts +3 -12
  60. dara/core/logging.py +12 -13
  61. dara/core/main.py +11 -14
  62. dara/core/metrics/cache.py +1 -1
  63. dara/core/metrics/utils.py +3 -3
  64. dara/core/persistence.py +5 -27
  65. dara/core/umd/dara.core.umd.js +55425 -59091
  66. dara/core/visual/components/__init__.py +2 -2
  67. dara/core/visual/components/fallback.py +4 -30
  68. dara/core/visual/components/for_cmp.py +1 -4
  69. dara/core/visual/css/__init__.py +31 -30
  70. dara/core/visual/dynamic_component.py +28 -31
  71. dara/core/visual/progress_updater.py +3 -4
  72. {dara_core-1.20.0a1.dist-info → dara_core-1.20.1a1.dist-info}/METADATA +11 -12
  73. dara_core-1.20.1a1.dist-info/RECORD +114 -0
  74. dara/core/interactivity/client_variable.py +0 -71
  75. dara/core/interactivity/server_variable.py +0 -325
  76. dara/core/interactivity/state_variable.py +0 -69
  77. dara/core/interactivity/tabular_variable.py +0 -94
  78. dara/core/internal/multi_resource_lock.py +0 -70
  79. dara_core-1.20.0a1.dist-info/RECORD +0 -119
  80. {dara_core-1.20.0a1.dist-info → dara_core-1.20.1a1.dist-info}/LICENSE +0 -0
  81. {dara_core-1.20.0a1.dist-info → dara_core-1.20.1a1.dist-info}/WHEEL +0 -0
  82. {dara_core-1.20.0a1.dist-info → dara_core-1.20.1a1.dist-info}/entry_points.txt +0 -0
@@ -15,30 +15,26 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-import contextlib
 import inspect
 import math
-from collections.abc import Awaitable
-from typing import Any, Callable, Dict, List, Optional, Set, Union, overload
+from typing import Any, Awaitable, Callable, Dict, List, Optional, Union, overload
 
 from anyio import (
-    BrokenResourceError,
     CancelScope,
     ClosedResourceError,
     create_memory_object_stream,
     create_task_group,
-    get_cancelled_exc_class,
    move_on_after,
 )
 from anyio.abc import TaskGroup
 from anyio.streams.memory import MemoryObjectSendStream
-from exceptiongroup import ExceptionGroup, catch
+from exceptiongroup import ExceptionGroup
 from pydantic import ConfigDict
 
 from dara.core.base_definitions import (
     BaseTask,
+    Cache,
     CachedRegistryEntry,
-    LruCachePolicy,
     PendingTask,
     TaskError,
     TaskMessage,
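
For context on the typing change above: collections.abc.Awaitable (dropped) and typing.Awaitable (restored) name the same protocol, so the swap does not affect annotations; typing.Awaitable is simply the older alias, soft-deprecated since Python 3.9 but still functional. A minimal sketch, with a hypothetical fetch coroutine used purely for illustration:

    import asyncio
    from collections.abc import Awaitable as AbcAwaitable
    from typing import Awaitable as TypingAwaitable

    async def fetch() -> int:  # hypothetical coroutine, illustration only
        return 1

    async def consume(aw: TypingAwaitable[int]) -> int:
        # TypingAwaitable[int] and AbcAwaitable[int] are interchangeable here
        return await aw

    async def main() -> None:
        coro = fetch()
        # the abc form also supports runtime isinstance checks
        assert isinstance(coro, AbcAwaitable)
        print(await consume(coro))  # 1

    asyncio.run(main())
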
@@ -49,7 +45,7 @@ from dara.core.internal.cache_store import CacheStore
 from dara.core.internal.devtools import get_error_for_channel
 from dara.core.internal.pandas_utils import remove_index
 from dara.core.internal.pool import TaskPool
-from dara.core.internal.utils import exception_group_contains, run_user_handler
+from dara.core.internal.utils import resolve_exception_group, run_user_handler
 from dara.core.internal.websocket import WebsocketManager
 from dara.core.logging import dev_logger, eng_logger
 from dara.core.metrics import RUNTIME_METRICS_TRACKER
@@ -141,12 +137,14 @@ class Task(BaseTask):
 
         async def on_progress(progress: float, msg: str):
             if send_stream is not None:
-                with contextlib.suppress(ClosedResourceError):
+                try:
                     await send_stream.send(TaskProgressUpdate(task_id=self.task_id, progress=progress, message=msg))
+                except ClosedResourceError:
+                    pass
 
         async def on_result(result: Any):
             if send_stream is not None:
-                with contextlib.suppress(ClosedResourceError):
+                try:
                     await send_stream.send(
                         TaskResult(
                             task_id=self.task_id,
@@ -155,15 +153,19 @@ class Task(BaseTask):
                             reg_entry=self.reg_entry,
                         )
                     )
+                except ClosedResourceError:
+                    pass
 
         async def on_error(exc: BaseException):
             if send_stream is not None:
-                with contextlib.suppress(ClosedResourceError):
+                try:
                     await send_stream.send(
                         TaskError(
                             task_id=self.task_id, error=exc, cache_key=self.cache_key, reg_entry=self.reg_entry
                         )
                     )
+                except ClosedResourceError:
+                    pass
 
         with pool.on_progress(self.task_id, on_progress):
             pool_task_def = pool.submit(self.task_id, self._func_name, args=tuple(self._args), kwargs=self._kwargs)
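
The two error-handling forms exchanged in the hunks above are behaviourally equivalent: contextlib.suppress(E) is shorthand for try/except E: pass. A self-contained sketch (the ClosedResourceError class below is a local stand-in, not anyio's):

    import contextlib

    class ClosedResourceError(Exception):
        """Local stand-in for anyio.ClosedResourceError, illustration only."""

    def send(msg: str) -> None:
        raise ClosedResourceError()

    # 1.20.0a1 form: suppress via context manager
    with contextlib.suppress(ClosedResourceError):
        send('progress')

    # 1.20.1a1 form: equivalent explicit try/except
    try:
        send('progress')
    except ClosedResourceError:
        pass
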
@@ -216,7 +218,7 @@ class MetaTask(BaseTask):
         :param notify_channels: If this task is run in a TaskManager instance these channels will also be notified on
                                 completion
         :param process_as_task: Whether to run the process_result function as a task or not, defaults to False
-        :param cache_key: Optional cache key if there is a registry entry to store results for the task in
+        :param cache_key: Optional cache key if there is a PendingTask in the store associated with this task
         :param task_id: Optional task_id to set for the task - otherwise the task generates its id automatically
         """
         self.args = args if args is not None else []
@@ -237,122 +239,96 @@ class MetaTask(BaseTask):
 
         :param send_stream: The stream to send messages to the task manager on
         """
-        self.cancel_scope = CancelScope()
+        tasks: List[BaseTask] = []
 
-        try:
-            with self.cancel_scope:
-                tasks: List[BaseTask] = []
-
-                # Collect up the tasks that need to be run and kick them off without awaiting them.
-                tasks.extend(x for x in self.args if isinstance(x, BaseTask))
-                tasks.extend(x for x in self.kwargs.values() if isinstance(x, BaseTask))
-
-                eng_logger.info(f'MetaTask {self.task_id} running sub-tasks', {'task_ids': [x.task_id for x in tasks]})
-
-                # Wait for all tasks to complete
-                results: Dict[str, Any] = {}
-
-                async def _run_and_capture_result(task: BaseTask):
-                    """
-                    Run a task and capture the result
-                    """
-                    nonlocal results
-                    result = await task.run(send_stream)
-                    results[task.task_id] = result
-
-                def handle_exception(err: ExceptionGroup):
-                    with contextlib.suppress(ClosedResourceError, BrokenResourceError):
-                        if send_stream is not None:
-                            send_stream.send_nowait(
-                                TaskError(
-                                    task_id=self.task_id, error=err, cache_key=self.cache_key, reg_entry=self.reg_entry
-                                )
-                            )
-                    raise err
-
-                if len(tasks) > 0:
-                    with catch(
-                        {
-                            BaseException: handle_exception,  # type: ignore
-                            get_cancelled_exc_class(): handle_exception,
-                        }
-                    ):
-                        async with create_task_group() as tg:
-                            for task in tasks:
-                                tg.start_soon(_run_and_capture_result, task)
-
-                        # In testing somehow sometimes cancellations aren't bubbled up, this ensures that's the case
-                        if tg.cancel_scope.cancel_called:
-                            raise get_cancelled_exc_class()('MetaTask caught cancellation')
-
-                eng_logger.debug(f'MetaTask {self.task_id}', 'completed sub-tasks', results)
-
-                # Order the results in the same order as the tasks list
-                result_values = [results[task.task_id] for task in tasks]
-
-                args = []
-                kwargs = {}
-
-                # Rebuild the args and kwargs with the results of the underlying tasks
-                # Here the task results could be DataFrames so make sure we clean the internal __index__ col from them
-                # before passing into the task function
-                for arg in self.args:
-                    if isinstance(arg, BaseTask):
-                        args.append(remove_index(result_values.pop(0)))
-                    else:
-                        args.append(arg)
-
-                for k, val in self.kwargs.items():
-                    if isinstance(val, BaseTask):
-                        kwargs[k] = remove_index(result_values.pop(0))
-                    else:
-                        kwargs[k] = val
-
-                eng_logger.debug(f'MetaTask {self.task_id}', 'processing result', {'args': args, 'kwargs': kwargs})
-
-                # Run the process result function with the completed set of args and kwargs
-                if self.process_as_task:
-                    eng_logger.debug(f'MetaTask {self.task_id}', 'processing result as Task')
-
-                    task = Task(
-                        self.process_result,
-                        args,
-                        kwargs,
-                        # Pass through cache_key so the processing task correctly updates the cache store entry
-                        reg_entry=self.reg_entry,
-                        cache_key=self.cache_key,
-                        task_id=self.task_id,
+        # Collect up the tasks that need to be run and kick them off without awaiting them.
+        tasks.extend(x for x in self.args if isinstance(x, BaseTask))
+        tasks.extend(x for x in self.kwargs.values() if isinstance(x, BaseTask))
+
+        eng_logger.info(f'MetaTask {self.task_id} running sub-tasks', {'task_ids': [x.task_id for x in tasks]})
+
+        # Wait for all tasks to complete
+        results: Dict[str, Any] = {}
+
+        async def _run_and_capture_result(task: BaseTask):
+            """
+            Run a task and capture the result
+            """
+            nonlocal results
+            result = await task.run(send_stream)
+            results[task.task_id] = result
+
+        if len(tasks) > 0:
+            try:
+                async with create_task_group() as tg:
+                    self.cancel_scope = tg.cancel_scope
+                    for task in tasks:
+                        tg.start_soon(_run_and_capture_result, task)
+            except BaseException as e:
+                if send_stream is not None:
+                    await send_stream.send(
+                        TaskError(task_id=self.task_id, error=e, cache_key=self.cache_key, reg_entry=self.reg_entry)
                     )
-                    res = await task.run(send_stream)
+                raise
+            finally:
+                self.cancel_scope = None
 
-                    eng_logger.info(f'MetaTask {self.task_id} returning result', {'result': res})
+        eng_logger.debug(f'MetaTask {self.task_id}', 'completed sub-tasks', results)
 
-                    return res
+        # Order the results in the same order as the tasks list
+        result_values = [results[task.task_id] for task in tasks]
 
-                try:
-                    res = await run_user_handler(self.process_result, args, kwargs)
+        args = []
+        kwargs = {}
 
-                    # Send MetaTask result - it could be that there is a nested structure
-                    # of MetaTasks so we need to make sure intermediate results are also sent
-                    if send_stream is not None:
-                        await send_stream.send(
-                            TaskResult(
-                                task_id=self.task_id, result=res, cache_key=self.cache_key, reg_entry=self.reg_entry
-                            )
-                        )
-                except BaseException as e:
-                    # Recover from error - update the pending value to prevent subsequent requests getting stuck
-                    if send_stream is not None:
-                        await send_stream.send(
-                            TaskError(task_id=self.task_id, error=e, cache_key=self.cache_key, reg_entry=self.reg_entry)
-                        )
-                    raise
+        # Rebuild the args and kwargs with the results of the underlying tasks
+        # Here the task results could be DataFrames so make sure we clean the internal __index__ col from them
+        # before passing into the task function
+        for arg in self.args:
+            if isinstance(arg, BaseTask):
+                args.append(remove_index(result_values.pop(0)))
+            else:
+                args.append(arg)
+
+        for k, val in self.kwargs.items():
+            if isinstance(val, BaseTask):
+                kwargs[k] = remove_index(result_values.pop(0))
+            else:
+                kwargs[k] = val
+
+        eng_logger.debug(f'MetaTask {self.task_id}', 'processing result', {'args': args, 'kwargs': kwargs})
 
-                eng_logger.info(f'MetaTask {self.task_id} returning result', {'result': res})
+        # Run the process result function with the completed set of args and kwargs
+        if self.process_as_task:
+            eng_logger.debug(f'MetaTask {self.task_id}', 'processing result as Task')
+            # Pass through cache_key so the processing task correctly updates the cache store entry
+            task = Task(self.process_result, args, kwargs, cache_key=self.cache_key)
+            res = await task.run(send_stream)
 
-                return res
-        finally:
-            self.cancel_scope = None
+            eng_logger.info(f'MetaTask {self.task_id} returning result', {'result': res})
+
+            return res
+
+        try:
+            res = await run_user_handler(self.process_result, args, kwargs)
+
+            # Send MetaTask result - it could be that there is a nested structure
+            # of MetaTasks so we need to make sure intermediate results are also sent
+            if send_stream is not None:
+                await send_stream.send(
+                    TaskResult(task_id=self.task_id, result=res, cache_key=self.cache_key, reg_entry=self.reg_entry)
+                )
+        except BaseException as e:
+            # Recover from error - update the pending value to prevent subsequent requests getting stuck
+            if send_stream is not None:
+                await send_stream.send(
+                    TaskError(task_id=self.task_id, error=e, cache_key=self.cache_key, reg_entry=self.reg_entry)
+                )
+            raise
+
+        eng_logger.info(f'MetaTask {self.task_id} returning result', {'result': res})
+
+        return res
 
     async def cancel(self):
         """
@@ -371,10 +347,8 @@ class TaskManager:
     TaskManager is responsible for running tasks and managing their pending state. It is also responsible for
     communicating the state of tasks to the client via the WebsocketManager.
 
-    Every task created gets registered with the TaskManager and is tracked by the TaskManager
-    as a PendingTask.
-
-    This allows the task to be retrieved by the cache_key from the store.
+    When a task is run, a PendingTask is stored in the tasks dict. It is also stored in the store
+    with the key of the task's cache_key. This allows the task to be retrieved by the cache_key from the store.
 
     When a task is completed, it is removed from the tasks dict and the store entry is updated with the result.
 
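
The updated docstring describes cache-key coordination: a pending placeholder is stored under the task's cache key so that concurrent requests await the same computation instead of recomputing it. A generic sketch of the idea using asyncio futures (this is not the package's actual store API):

    import asyncio

    store: dict[str, object] = {}

    async def get_or_compute(cache_key: str, compute):
        """Return a cached value, await an in-flight computation, or start one."""
        existing = store.get(cache_key)
        if isinstance(existing, asyncio.Future):
            return await existing          # someone else is already computing
        if existing is not None:
            return existing                # plain cached value

        pending: asyncio.Future = asyncio.get_running_loop().create_future()
        store[cache_key] = pending         # acts like the PendingTask placeholder
        result = await compute()
        store[cache_key] = result          # replace placeholder with the result
        pending.set_result(result)
        return result

    async def main():
        async def compute():
            await asyncio.sleep(0.01)
            return 42

        first, second = await asyncio.gather(
            get_or_compute('key', compute),
            get_or_compute('key', compute),
        )
        print(first, second)  # 42 42 - computed once, shared twice

    asyncio.run(main())
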
@@ -387,19 +361,13 @@ class TaskManager:
         self.ws_manager = ws_manager
         self.store = store
 
-    def register_task(self, task: BaseTask) -> PendingTask:
-        """
-        Register a task. This will ensure the task is tracked and notifications are routed correctly.
-        """
-        pending_task = PendingTask(task.task_id, task)
-        self.tasks[task.task_id] = pending_task
-        return pending_task
-
     @overload
-    async def run_task(self, task: PendingTask, ws_channel: Optional[str] = None) -> Any: ...
+    async def run_task(self, task: PendingTask, ws_channel: Optional[str] = None) -> Any:
+        ...
 
     @overload
-    async def run_task(self, task: BaseTask, ws_channel: Optional[str] = None) -> PendingTask: ...
+    async def run_task(self, task: BaseTask, ws_channel: Optional[str] = None) -> PendingTask:
+        ...
 
     async def run_task(self, task: BaseTask, ws_channel: Optional[str] = None):
         """
@@ -412,9 +380,7 @@ class TaskManager:
         # append the websocket channel to the task
         if isinstance(task, PendingTask):
             if task.task_id in self.tasks:
-                # Increment subscriber count for this component request
-                self.tasks[task.task_id].add_subscriber()
-                if ws_channel is not None:
+                if ws_channel:
                     self.tasks[task.task_id].notify_channels.append(ws_channel)
                 return self.tasks[task.task_id]
 
@@ -428,65 +394,21 @@ class TaskManager:
             else self.get_result(task.task_id)
         )
 
-        # Otherwise, we should already have a pending task for this task
-        pending_task = self.tasks[task.task_id]
-        if ws_channel is not None:
-            pending_task.notify_channels.append(ws_channel)
+        # Create and store the pending task
+        pending_task = PendingTask(task.task_id, task, ws_channel)
+        if task.cache_key is not None and task.reg_entry is not None:
+            await self.store.set(task.reg_entry, key=task.cache_key, value=pending_task)
+
+        self.tasks[task.task_id] = pending_task
 
         # Run the task in the background
-        self.task_group.start_soon(self._run_task_and_notify, task)
+        self.task_group.start_soon(self._run_task_and_notify, task, ws_channel)
 
         return pending_task
 
-    async def _cancel_tasks(self, task_ids: List[str], notify: bool = True):
-        """
-        Cancel a list of tasks
-
-        :param task_ids: The list of task IDs to cancel
-        :param notify: Whether to send cancellation notifications
-        """
-        with CancelScope(shield=True):
-            # Cancel all tasks in the hierarchy
-            for task_id_to_cancel in task_ids:
-                if task_id_to_cancel in self.tasks:
-                    pending_task = self.tasks[task_id_to_cancel]
-
-                    # Notify channels that this specific task was cancelled
-                    if notify:
-                        await self._send_notification_for_pending_task(
-                            pending_task=pending_task,
-                            messages=[{'status': 'CANCELED', 'task_id': task_id_to_cancel}],
-                        )
-
-                    if not pending_task.event.is_set():
-                        # Cancel the actual task
-                        await pending_task.cancel()
-
-                    # Remove from cache if it has cache settings
-                    if pending_task.task_def.cache_key is not None and pending_task.task_def.reg_entry is not None:
-                        await self.store.delete(
-                            pending_task.task_def.reg_entry, key=pending_task.task_def.cache_key
-                        )
-
-                    # Remove from running tasks
-                    self.tasks.pop(task_id_to_cancel, None)
-
-                    dev_logger.info('Task cancelled', {'task_id': task_id_to_cancel})
-
-    async def _cancel_task_hierarchy(self, task: BaseTask, notify: bool = True):
-        """
-        Recursively cancel all tasks in a task hierarchy
-
-        :param task: The root task to cancel (and its children)
-        :param notify: Whether to send cancellation notifications
-        """
-        all_task_ids = self._collect_all_task_ids_in_hierarchy(task)
-        await self._cancel_tasks(list(all_task_ids), notify)
-
     async def cancel_task(self, task_id: str, notify: bool = True):
         """
-        Cancel a running task by its id. If the task has child tasks (MetaTask),
-        all child tasks will also be cancelled.
+        Cancel a running task by its id
 
         :param task_id: the id of the task
         :param notify: whether to notify, true by default
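
The removed _cancel_tasks helper wrapped its cleanup in CancelScope(shield=True) so that the cancellation notifications themselves cannot be cancelled mid-flight. The general pattern, in a standalone anyio sketch:

    import anyio

    async def do_cleanup():
        await anyio.sleep(0.01)  # e.g. notify channels, evict cache entries

    async def worker():
        try:
            await anyio.sleep(10)  # gets cancelled below
        finally:
            # A shielded scope keeps running even while the surrounding
            # task is being cancelled, so cleanup is guaranteed to finish.
            with anyio.CancelScope(shield=True):
                await do_cleanup()

    async def main():
        async with anyio.create_task_group() as tg:
            tg.start_soon(worker)
            await anyio.sleep(0.01)
            tg.cancel_scope.cancel()

    anyio.run(main)
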
@@ -501,9 +423,22 @@ class TaskManager:
                 task.remove_subscriber()
                 return
 
-            # Cancel the entire task hierarchy (including child tasks)
-            await self._cancel_task_hierarchy(task.task_def, notify)
+            # Notify any listening channels that the job has been cancelled so that they can handle it correctly
+            if notify:
+                for channel in [*task.notify_channels, *task.task_def.notify_channels]:
+                    await self.ws_manager.send_message(channel, {'status': 'CANCELED', 'task_id': task_id})
+
+            # We're only now cancelling the task to make sure the clients are notified about cancelling
+            # and receive the correct status rather than an error
+            await task.cancel()
+
+            # Then remove the pending task from cache so next requests would recalculate rather than receive
+            # a broken pending task
+            if task.task_def.cache_key is not None and task.task_def.reg_entry is not None:
+                await self.store.set(task.task_def.reg_entry, key=task.task_def.cache_key, value=None)
 
+            # Remove from running tasks
+            self.tasks.pop(task_id, None)
         else:
             raise TaskManagerError('Could not find a task with the passed id to cancel.')
 
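
The restored cancel_task body is ordering-sensitive: listeners are notified first so clients receive a clean CANCELED status rather than an error from a dead task, then the work is cancelled, and the cache slot is cleared last so the next request recomputes. A compressed, runnable sketch of that sequence (all names are illustrative):

    import asyncio

    class ToyTask:
        """Illustrative stand-in for a PendingTask."""

        def __init__(self, task_id: str, cache_key: str, channels: list) -> None:
            self.task_id = task_id
            self.cache_key = cache_key
            self.channels = channels
            self._job = asyncio.ensure_future(asyncio.sleep(60))

        async def cancel(self) -> None:
            self._job.cancel()
            try:
                await self._job
            except asyncio.CancelledError:
                pass

    async def notify(channel: str, message: dict) -> None:
        print(channel, message)  # stand-in for a websocket send

    async def cancel_task(task_id: str, tasks: dict, store: dict) -> None:
        task = tasks[task_id]
        # 1. Notify first, so clients see CANCELED rather than a late error
        for channel in task.channels:
            await notify(channel, {'status': 'CANCELED', 'task_id': task_id})
        # 2. Then cancel the running work
        await task.cancel()
        # 3. Finally clear the cache slot so the next request recomputes
        store[task.cache_key] = None
        tasks.pop(task_id, None)

    async def main() -> None:
        task = ToyTask('t1', 'k1', ['chan-a'])
        await cancel_task('t1', {'t1': task}, {'k1': task})

    asyncio.run(main())
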
@@ -524,9 +459,12 @@ class TaskManager:
 
         :param task_id: the id of the task to fetch
         """
-        # the result is not deleted, the results are kept in an LRU cache
-        # which will clean up older entries
-        return await self.store.get(TaskResultEntry, key=task_id)
+        result = await self.store.get(TaskResultEntry, key=task_id)
+
+        # Clean up the result afterwards
+        await self.store.delete(TaskResultEntry, key=task_id)
+
+        return result
 
     async def set_result(self, task_id: str, value: Any):
         """
@@ -534,82 +472,7 @@ class TaskManager:
         """
         return await self.store.set(TaskResultEntry, key=task_id, value=value)
 
-    def _collect_all_task_ids_in_hierarchy(self, task: BaseTask) -> Set[str]:
-        """
-        Recursively collect all task IDs in the task hierarchy
-
-        :param task: The root task to start collecting from
-        :return: Set of all task IDs in the hierarchy
-        """
-        task_ids = {task.task_id}
-
-        if isinstance(task, MetaTask):
-            # Collect from args
-            for arg in task.args:
-                if isinstance(arg, BaseTask):
-                    task_ids.update(self._collect_all_task_ids_in_hierarchy(arg))
-
-            # Collect from kwargs
-            for value in task.kwargs.values():
-                if isinstance(value, BaseTask):
-                    task_ids.update(self._collect_all_task_ids_in_hierarchy(value))
-
-        return task_ids
-
-    async def _multicast_notification(self, task_id: str, messages: List[dict]):
-        """
-        Send notifications to all task IDs that are related to a given task
-
-        :param task: the task the notifications are related to
-        :param messages: List of message dictionaries to send to all related tasks
-        """
-        # prevent cancellation, we need the notifications to be sent
-        with CancelScope(shield=True):
-            # Find all PendingTasks that have the message_task_id in their hierarchy
-            tasks_to_notify = set()
-
-            for pending_task in self.tasks.values():
-                # Check if the message_task_id is in this PendingTask's hierarchy
-                task_ids_in_hierarchy = self._collect_all_task_ids_in_hierarchy(pending_task.task_def)
-                if task_id in task_ids_in_hierarchy:
-                    tasks_to_notify.add(pending_task.task_id)
-
-            # Send notifications for all affected PendingTasks in parallel
-            if tasks_to_notify:
-                async with create_task_group() as task_tg:
-                    for pending_task_id in tasks_to_notify:
-                        if pending_task_id not in self.tasks:
-                            continue
-                        pending_task = self.tasks[pending_task_id]
-                        task_tg.start_soon(self._send_notification_for_pending_task, pending_task, messages)
-
-    async def _send_notification_for_pending_task(self, pending_task: PendingTask, messages: List[dict]):
-        """
-        Send notifications for a specific PendingTask
-
-        :param pending_task: The PendingTask to send notifications for
-        :param messages: The messages to send
-        """
-        # Collect channels for this PendingTask
-        channels_to_notify = set(pending_task.notify_channels)
-        channels_to_notify.update(pending_task.task_def.notify_channels)
-
-        if not channels_to_notify:
-            return
-
-        # Send to all channels for this PendingTask in parallel
-        async def _send_to_channel(channel: str):
-            async with create_task_group() as channel_tg:
-                for message in messages:
-                    # Create message with this PendingTask's task_id (if message has task_id)
-                    message_for_task = {**message, 'task_id': pending_task.task_id} if 'task_id' in message else message
-                    channel_tg.start_soon(self.ws_manager.send_message, channel, message_for_task)
-
-        async with create_task_group() as channel_tg:
-            for channel in channels_to_notify:
-                channel_tg.start_soon(_send_to_channel, channel)
-
-    async def _run_task_and_notify(self, task: BaseTask):
+    async def _run_task_and_notify(self, task: BaseTask, ws_channel: Optional[str]):
         """
         Run the task to completion and notify the client of progress and completion
 
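
The deleted _multicast_notification/_send_notification_for_pending_task pair fanned messages out to every subscribed channel in parallel via nested task groups, while 1.20.1a1 returns to a simple sequential loop (see notify_channels in the next hunk). The fan-out pattern in generic form (illustrative names):

    import anyio

    async def send_message(channel: str, message: dict) -> None:
        await anyio.sleep(0.01)  # stand-in for a websocket send

    async def multicast(channels: set, messages: list) -> None:
        # One child task per (channel, message) pair; the group waits for all.
        async with anyio.create_task_group() as tg:
            for channel in channels:
                for message in messages:
                    tg.start_soon(send_message, channel, message)

    anyio.run(multicast, {'chan-a', 'chan-b'}, [{'status': 'PROGRESS'}])
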
@@ -618,13 +481,23 @@ class TaskManager:
         """
         cancel_scope = CancelScope()
 
-        pending_task = self.tasks[task.task_id]
-
-        pending_task.cancel_scope = cancel_scope
+        self.tasks[task.task_id].cancel_scope = cancel_scope
 
         with cancel_scope:
             eng_logger.info(f'TaskManager running task {task.task_id}')
 
+            async def notify_channels(*messages: dict):
+                """
+                Notify the channels of the task's progress
+                """
+                channels_to_notify = [*task.notify_channels]
+                if ws_channel:
+                    channels_to_notify.append(ws_channel)
+
+                for channel in channels_to_notify:
+                    for message in messages:
+                        await self.ws_manager.send_message(channel, message)
+
             # Create a memory object stream to capture messages from the tasks
             send_stream, receive_stream = create_memory_object_stream[TaskMessage](math.inf)
 
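
The stream created above is the backbone of task-to-manager messaging: sub-tasks push TaskProgressUpdate/TaskResult/TaskError objects into send_stream while handle_messages drains receive_stream. A minimal producer/consumer sketch of the same anyio primitive:

    import math
    import anyio

    async def main():
        # An unbounded in-memory channel, as in the hunk above
        send_stream, receive_stream = anyio.create_memory_object_stream(math.inf)

        async def producer():
            async with send_stream:
                for i in range(3):
                    await send_stream.send({'progress': i})

        async def consumer():
            async with receive_stream:
                # The loop ends once the send side is closed
                async for message in receive_stream:
                    print(message)

        async with anyio.create_task_group() as tg:
            tg.start_soon(producer)
            tg.start_soon(consumer)

    anyio.run(main)
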
@@ -632,41 +505,32 @@ class TaskManager:
             async with receive_stream:
                 async for message in receive_stream:
                     if isinstance(message, TaskProgressUpdate):
-                        # Send progress notifications to related tasks
-                        progress_message = {
-                            'task_id': message.task_id,  # Will be updated per task ID in multicast
-                            'status': 'PROGRESS',
-                            'progress': message.progress,
-                            'message': message.message,
-                        }
-                        await self._multicast_notification(message.task_id, [progress_message])
+                        # Notify the channels of the task's progress
+                        await notify_channels(
+                            {
+                                'task_id': task.task_id,
+                                'status': 'PROGRESS',
+                                'progress': message.progress,
+                                'message': message.message,
+                            }
+                        )
                         if isinstance(task, Task) and task.on_progress:
                             await run_user_handler(task.on_progress, args=(message,))
                     elif isinstance(message, TaskResult):
-                        # Handle dual coordination patterns:
-                        # 1. Direct-coordinated tasks: resolve via active_tasks registry
+                        # Resolve the pending task related to the result
                         if message.task_id in self.tasks:
-                            self.tasks[message.task_id].resolve(message.result)
-
-                        # 2. Cache-coordinated tasks: resolve via cache store (CacheStore.set handles PendingTask resolution)
+                            self.tasks[task.task_id].resolve(message.result)
+                        # If the task has a cache key, update the cached value
                         if (
                             message.cache_key is not None
                             and message.reg_entry is not None
                             and message.reg_entry.cache is not None
                         ):
                             await self.store.set(message.reg_entry, key=message.cache_key, value=message.result)
-
-                            # Set final result
-                            await self.set_result(message.task_id, message.result)
-
-                        # Notify all PendingTasks that depend on this specific task
-                        await self._multicast_notification(
-                            task_id=message.task_id,
-                            messages=[{'result': message.result, 'status': 'COMPLETE', 'task_id': message.task_id}],
+                        # Notify the channels of the task's completion
+                        await notify_channels(
+                            {'result': message.result, 'status': 'COMPLETE', 'task_id': message.task_id}
                         )
-
-                        # Remove the task from the registered tasks - it finished running
-                        self.tasks.pop(message.task_id, None)
                     elif isinstance(message, TaskError):
                         # Fail the pending task related to the error
                         if message.task_id in self.tasks:
@@ -679,94 +543,46 @@ class TaskManager:
                             and message.reg_entry is not None
                             and message.reg_entry.cache is not None
                         ):
-                            await self.store.delete(message.reg_entry, key=message.cache_key)
-
-                        # Notify all PendingTasks that depend on this specific task
-                        error = get_error_for_channel(message.error)
-                        await self._multicast_notification(
-                            task_id=message.task_id,
-                            messages=[
-                                {'status': 'ERROR', 'task_id': message.task_id, 'error': error['error']},
-                                error,
-                            ],
-                        )
+                            await self.store.set(message.reg_entry, key=message.cache_key, value=None)
 
-                        # Remove the task from the registered tasks - it finished running
-                        self.tasks.pop(message.task_id, None)
+            try:
+                async with create_task_group() as tg:
+                    # Handle incoming messages in parallel
+                    tg.start_soon(handle_messages)
 
-            task_error: Optional[ExceptionGroup] = None
+                    # Handle tasks that return other tasks
+                    async with send_stream:
+                        result = task
+                        while isinstance(result, BaseTask):
+                            result = await task.run(send_stream)
 
-            # ExceptionGroup handler can't be async so we just mark the task as errored
-            # and run the async handler in the finally block
-            def handle_exception(err: ExceptionGroup):
-                nonlocal task_error
-                task_error = err
+                        # Set final result
+                        await self.set_result(task.task_id, result)
 
-            try:
-                with catch({BaseException: handle_exception}):  # type: ignore
-                    async with create_task_group() as tg:
-                        # Handle incoming messages in parallel
-                        tg.start_soon(handle_messages)
-
-                        # Handle tasks that return other tasks
-                        async with send_stream:
-                            result = task
-                            while isinstance(result, BaseTask):
-                                result = await task.run(send_stream)
-
-                            # Notify any channels that need to be notified about the whole task being completed
-                            await send_stream.send(
-                                TaskResult(
-                                    task_id=task.task_id,
-                                    result=result,
-                                    cache_key=task.cache_key,
-                                    reg_entry=task.reg_entry,
-                                )
-                            )
-                            eng_logger.info(f'TaskManager finished task {task.task_id}', {'result': result})
-            finally:
-                with CancelScope(shield=True):
-                    # pyright: ignore[reportUnreachable]
-                    if task_error is not None:
-                        err = task_error
-                        # Mark pending task as failed
-                        pending_task.fail(err)
-
-                        await self.set_result(task.task_id, {'error': str(err)})
-
-                        # If the task has a cache key, set cached value to None
-                        # This makes it so that the next request will recalculate the value rather than keep failing
-                        if (
-                            task.cache_key is not None
-                            and task.reg_entry is not None
-                            and task.reg_entry.cache is not None
-                        ):
-                            await self.store.delete(task.reg_entry, key=task.cache_key)
-
-                        # If this is a cancellation, ensure all tasks in the hierarchy are cancelled
-                        if exception_group_contains(err_type=get_cancelled_exc_class(), group=err):
-                            dev_logger.info('Task cancelled', {'task_id': task.task_id})
-
-                            # Cancel any remaining tasks in the hierarchy that might still be running
-                            await self._cancel_task_hierarchy(task, notify=True)
-                        else:
-                            dev_logger.error('Task failed', err, {'task_id': task.task_id})
-
-                            error = get_error_for_channel(err)
-                            message = {'status': 'ERROR', 'task_id': task.task_id, 'error': error['error']}
-                            # Notify about this task failing, and a server broadcast error
-                            await self._send_notification_for_pending_task(
-                                pending_task=pending_task,
-                                messages=[message, error],
-                            )
-                            # notify related tasks
-                            await self._multicast_notification(
+                        # Notify any channels that need to be notified about the whole task being completed
+                        await send_stream.send(
+                            TaskResult(
                                task_id=task.task_id,
-                                messages=[message],
+                                result=result,
+                                cache_key=task.cache_key,
+                                reg_entry=task.reg_entry,
                             )
+                        )
+                        eng_logger.info(f'TaskManager finished task {task.task_id}', {'result': result})
+            except (Exception, ExceptionGroup) as err:
+                err = resolve_exception_group(err)
+
+                # Mark pending task as failed
+                self.tasks[task.task_id].fail(err)
 
-                    # Remove the task from the running tasks
-                    self.tasks.pop(task.task_id, None)
+                dev_logger.error('Task failed', err, {'task_id': task.task_id})
+                await self.set_result(task.task_id, {'error': str(err)})
+
+                # Notify any channels that need to be notified
+                await notify_channels({'status': 'ERROR', 'task_id': task.task_id}, get_error_for_channel())
+            finally:
+                # Remove the task from the running tasks
+                self.tasks.pop(task.task_id, None)
 
                 # Make sure streams are closed
                 with move_on_after(3, shield=True):
@@ -774,7 +590,7 @@ class TaskManager:
                     await receive_stream.aclose()
 
 
-TaskResultEntry = CachedRegistryEntry(uid='task-results', cache=LruCachePolicy(max_size=256))
+TaskResultEntry = CachedRegistryEntry(uid='task-results', cache=Cache.Policy.KeepAll())
 """
 Global registry entry for task results.
 This is global because task ids are unique and accessed one time only so it's effectively a one-time use random key.
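
The final change swaps the bounded LruCachePolicy(max_size=256) for Cache.Policy.KeepAll(), which pairs with the new delete-on-first-read get_result above: each result is read exactly once and then removed, so no eviction policy is needed. The retention trade-off in miniature (a toy LRU, not the dara implementation):

    from collections import OrderedDict

    class LruStore:
        """Toy LRU store: bounded memory, results may be re-read until evicted."""

        def __init__(self, max_size: int) -> None:
            self.max_size = max_size
            self.data: OrderedDict = OrderedDict()

        def set(self, key: str, value: object) -> None:
            self.data[key] = value
            self.data.move_to_end(key)
            if len(self.data) > self.max_size:
                self.data.popitem(last=False)  # evict least recently used

        def get(self, key: str) -> object:
            value = self.data[key]
            self.data.move_to_end(key)
            return value

    class KeepAllStore:
        """Toy keep-all store paired with one-shot reads: get deletes the entry."""

        def __init__(self) -> None:
            self.data: dict = {}

        def set(self, key: str, value: object) -> None:
            self.data[key] = value

        def get(self, key: str) -> object:
            return self.data.pop(key)  # read once, then gone

    lru = LruStore(max_size=2)
    for key in ('a', 'b', 'c'):
        lru.set(key, key.upper())
    print(list(lru.data))  # ['b', 'c'] - 'a' was evicted
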