dara-core 1.19.1__py3-none-any.whl → 1.20.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51):
  1. dara/core/__init__.py +1 -0
  2. dara/core/auth/basic.py +13 -7
  3. dara/core/auth/definitions.py +2 -2
  4. dara/core/auth/utils.py +1 -1
  5. dara/core/base_definitions.py +7 -42
  6. dara/core/data_utils.py +16 -17
  7. dara/core/definitions.py +8 -8
  8. dara/core/interactivity/__init__.py +4 -0
  9. dara/core/interactivity/actions.py +20 -22
  10. dara/core/interactivity/any_data_variable.py +7 -135
  11. dara/core/interactivity/any_variable.py +1 -1
  12. dara/core/interactivity/client_variable.py +71 -0
  13. dara/core/interactivity/data_variable.py +8 -266
  14. dara/core/interactivity/derived_data_variable.py +6 -290
  15. dara/core/interactivity/derived_variable.py +335 -201
  16. dara/core/interactivity/filtering.py +29 -2
  17. dara/core/interactivity/loop_variable.py +2 -2
  18. dara/core/interactivity/non_data_variable.py +5 -68
  19. dara/core/interactivity/plain_variable.py +87 -14
  20. dara/core/interactivity/server_variable.py +325 -0
  21. dara/core/interactivity/state_variable.py +2 -2
  22. dara/core/interactivity/switch_variable.py +15 -15
  23. dara/core/interactivity/tabular_variable.py +94 -0
  24. dara/core/interactivity/url_variable.py +10 -90
  25. dara/core/internal/cache_store/cache_store.py +5 -20
  26. dara/core/internal/dependency_resolution.py +27 -69
  27. dara/core/internal/devtools.py +10 -3
  28. dara/core/internal/execute_action.py +9 -3
  29. dara/core/internal/multi_resource_lock.py +70 -0
  30. dara/core/internal/normalization.py +0 -5
  31. dara/core/internal/pandas_utils.py +105 -3
  32. dara/core/internal/pool/definitions.py +1 -1
  33. dara/core/internal/pool/task_pool.py +1 -1
  34. dara/core/internal/registries.py +3 -2
  35. dara/core/internal/registry.py +1 -1
  36. dara/core/internal/registry_lookup.py +5 -3
  37. dara/core/internal/routing.py +52 -121
  38. dara/core/internal/store.py +2 -29
  39. dara/core/internal/tasks.py +372 -182
  40. dara/core/internal/utils.py +25 -3
  41. dara/core/internal/websocket.py +1 -1
  42. dara/core/js_tooling/js_utils.py +2 -0
  43. dara/core/logging.py +10 -6
  44. dara/core/persistence.py +26 -4
  45. dara/core/umd/dara.core.umd.js +751 -1386
  46. dara/core/visual/dynamic_component.py +10 -13
  47. {dara_core-1.19.1.dist-info → dara_core-1.20.0.dist-info}/METADATA +10 -10
  48. {dara_core-1.19.1.dist-info → dara_core-1.20.0.dist-info}/RECORD +51 -47
  49. {dara_core-1.19.1.dist-info → dara_core-1.20.0.dist-info}/LICENSE +0 -0
  50. {dara_core-1.19.1.dist-info → dara_core-1.20.0.dist-info}/WHEEL +0 -0
  51. {dara_core-1.19.1.dist-info → dara_core-1.20.0.dist-info}/entry_points.txt +0 -0
@@ -19,24 +19,26 @@ import contextlib
19
19
  import inspect
20
20
  import math
21
21
  from collections.abc import Awaitable
22
- from typing import Any, Callable, Dict, List, Optional, Union, overload
22
+ from typing import Any, Callable, Dict, List, Optional, Set, Union, overload
23
23
 
24
24
  from anyio import (
25
+ BrokenResourceError,
25
26
  CancelScope,
26
27
  ClosedResourceError,
27
28
  create_memory_object_stream,
28
29
  create_task_group,
30
+ get_cancelled_exc_class,
29
31
  move_on_after,
30
32
  )
31
33
  from anyio.abc import TaskGroup
32
34
  from anyio.streams.memory import MemoryObjectSendStream
33
- from exceptiongroup import ExceptionGroup
35
+ from exceptiongroup import ExceptionGroup, catch
34
36
  from pydantic import ConfigDict
35
37
 
36
38
  from dara.core.base_definitions import (
37
39
  BaseTask,
38
- Cache,
39
40
  CachedRegistryEntry,
41
+ LruCachePolicy,
40
42
  PendingTask,
41
43
  TaskError,
42
44
  TaskMessage,
@@ -47,7 +49,7 @@ from dara.core.internal.cache_store import CacheStore
47
49
  from dara.core.internal.devtools import get_error_for_channel
48
50
  from dara.core.internal.pandas_utils import remove_index
49
51
  from dara.core.internal.pool import TaskPool
50
- from dara.core.internal.utils import resolve_exception_group, run_user_handler
52
+ from dara.core.internal.utils import exception_group_contains, run_user_handler
51
53
  from dara.core.internal.websocket import WebsocketManager
52
54
  from dara.core.logging import dev_logger, eng_logger
53
55
  from dara.core.metrics import RUNTIME_METRICS_TRACKER
@@ -214,7 +216,7 @@ class MetaTask(BaseTask):
214
216
  :param notify_channels: If this task is run in a TaskManager instance these channels will also be notified on
215
217
  completion
216
218
  :param process_as_task: Whether to run the process_result function as a task or not, defaults to False
217
- :param cache_key: Optional cache key if there is a PendingTask in the store associated with this task
219
+ :param cache_key: Optional cache key if there is a registry entry to store results for the task in
218
220
  :param task_id: Optional task_id to set for the task - otherwise the task generates its id automatically
219
221
  """
220
222
  self.args = args if args is not None else []
@@ -235,96 +237,122 @@ class MetaTask(BaseTask):
235
237
 
236
238
  :param send_stream: The stream to send messages to the task manager on
237
239
  """
238
- tasks: List[BaseTask] = []
239
-
240
- # Collect up the tasks that need to be run and kick them off without awaiting them.
241
- tasks.extend(x for x in self.args if isinstance(x, BaseTask))
242
- tasks.extend(x for x in self.kwargs.values() if isinstance(x, BaseTask))
243
-
244
- eng_logger.info(f'MetaTask {self.task_id} running sub-tasks', {'task_ids': [x.task_id for x in tasks]})
240
+ self.cancel_scope = CancelScope()
245
241
 
246
- # Wait for all tasks to complete
247
- results: Dict[str, Any] = {}
248
-
249
- async def _run_and_capture_result(task: BaseTask):
250
- """
251
- Run a task and capture the result
252
- """
253
- nonlocal results
254
- result = await task.run(send_stream)
255
- results[task.task_id] = result
256
-
257
- if len(tasks) > 0:
258
- try:
259
- async with create_task_group() as tg:
260
- self.cancel_scope = tg.cancel_scope
261
- for task in tasks:
262
- tg.start_soon(_run_and_capture_result, task)
263
- except BaseException as e:
264
- if send_stream is not None:
265
- await send_stream.send(
266
- TaskError(task_id=self.task_id, error=e, cache_key=self.cache_key, reg_entry=self.reg_entry)
242
+ try:
243
+ with self.cancel_scope:
244
+ tasks: List[BaseTask] = []
245
+
246
+ # Collect up the tasks that need to be run and kick them off without awaiting them.
247
+ tasks.extend(x for x in self.args if isinstance(x, BaseTask))
248
+ tasks.extend(x for x in self.kwargs.values() if isinstance(x, BaseTask))
249
+
250
+ eng_logger.info(f'MetaTask {self.task_id} running sub-tasks', {'task_ids': [x.task_id for x in tasks]})
251
+
252
+ # Wait for all tasks to complete
253
+ results: Dict[str, Any] = {}
254
+
255
+ async def _run_and_capture_result(task: BaseTask):
256
+ """
257
+ Run a task and capture the result
258
+ """
259
+ nonlocal results
260
+ result = await task.run(send_stream)
261
+ results[task.task_id] = result
262
+
263
+ def handle_exception(err: ExceptionGroup):
264
+ with contextlib.suppress(ClosedResourceError, BrokenResourceError):
265
+ if send_stream is not None:
266
+ send_stream.send_nowait(
267
+ TaskError(
268
+ task_id=self.task_id, error=err, cache_key=self.cache_key, reg_entry=self.reg_entry
269
+ )
270
+ )
271
+ raise err
272
+
273
+ if len(tasks) > 0:
274
+ with catch(
275
+ {
276
+ BaseException: handle_exception, # type: ignore
277
+ get_cancelled_exc_class(): handle_exception,
278
+ }
279
+ ):
280
+ async with create_task_group() as tg:
281
+ for task in tasks:
282
+ tg.start_soon(_run_and_capture_result, task)
283
+
284
+ # In testing somehow sometimes cancellations aren't bubbled up, this ensures that's the case
285
+ if tg.cancel_scope.cancel_called:
286
+ raise get_cancelled_exc_class()('MetaTask caught cancellation')
287
+
288
+ eng_logger.debug(f'MetaTask {self.task_id}', 'completed sub-tasks', results)
289
+
290
+ # Order the results in the same order as the tasks list
291
+ result_values = [results[task.task_id] for task in tasks]
292
+
293
+ args = []
294
+ kwargs = {}
295
+
296
+ # Rebuild the args and kwargs with the results of the underlying tasks
297
+ # Here the task results could be DataFrames so make sure we clean the internal __index__ col from them
298
+ # before passing into the task function
299
+ for arg in self.args:
300
+ if isinstance(arg, BaseTask):
301
+ args.append(remove_index(result_values.pop(0)))
302
+ else:
303
+ args.append(arg)
304
+
305
+ for k, val in self.kwargs.items():
306
+ if isinstance(val, BaseTask):
307
+ kwargs[k] = remove_index(result_values.pop(0))
308
+ else:
309
+ kwargs[k] = val
310
+
311
+ eng_logger.debug(f'MetaTask {self.task_id}', 'processing result', {'args': args, 'kwargs': kwargs})
312
+
313
+ # Run the process result function with the completed set of args and kwargs
314
+ if self.process_as_task:
315
+ eng_logger.debug(f'MetaTask {self.task_id}', 'processing result as Task')
316
+
317
+ task = Task(
318
+ self.process_result,
319
+ args,
320
+ kwargs,
321
+ # Pass through cache_key so the processing task correctly updates the cache store entry
322
+ reg_entry=self.reg_entry,
323
+ cache_key=self.cache_key,
324
+ task_id=self.task_id,
267
325
  )
268
- raise
269
- finally:
270
- self.cancel_scope = None
271
-
272
- eng_logger.debug(f'MetaTask {self.task_id}', 'completed sub-tasks', results)
326
+ res = await task.run(send_stream)
273
327
 
274
- # Order the results in the same order as the tasks list
275
- result_values = [results[task.task_id] for task in tasks]
328
+ eng_logger.info(f'MetaTask {self.task_id} returning result', {'result': res})
276
329
 
277
- args = []
278
- kwargs = {}
330
+ return res
279
331
 
280
- # Rebuild the args and kwargs with the results of the underlying tasks
281
- # Here the task results could be DataFrames so make sure we clean the internal __index__ col from them
282
- # before passing into the task function
283
- for arg in self.args:
284
- if isinstance(arg, BaseTask):
285
- args.append(remove_index(result_values.pop(0)))
286
- else:
287
- args.append(arg)
288
-
289
- for k, val in self.kwargs.items():
290
- if isinstance(val, BaseTask):
291
- kwargs[k] = remove_index(result_values.pop(0))
292
- else:
293
- kwargs[k] = val
294
-
295
- eng_logger.debug(f'MetaTask {self.task_id}', 'processing result', {'args': args, 'kwargs': kwargs})
296
-
297
- # Run the process result function with the completed set of args and kwargs
298
- if self.process_as_task:
299
- eng_logger.debug(f'MetaTask {self.task_id}', 'processing result as Task')
300
- # Pass through cache_key so the processing task correctly updates the cache store entry
301
- task = Task(self.process_result, args, kwargs, cache_key=self.cache_key)
302
- res = await task.run(send_stream)
332
+ try:
333
+ res = await run_user_handler(self.process_result, args, kwargs)
303
334
 
304
- eng_logger.info(f'MetaTask {self.task_id} returning result', {'result': res})
335
+ # Send MetaTask result - it could be that there is a nested structure
336
+ # of MetaTasks so we need to make sure intermediate results are also sent
337
+ if send_stream is not None:
338
+ await send_stream.send(
339
+ TaskResult(
340
+ task_id=self.task_id, result=res, cache_key=self.cache_key, reg_entry=self.reg_entry
341
+ )
342
+ )
343
+ except BaseException as e:
344
+ # Recover from error - update the pending value to prevent subsequent requests getting stuck
345
+ if send_stream is not None:
346
+ await send_stream.send(
347
+ TaskError(task_id=self.task_id, error=e, cache_key=self.cache_key, reg_entry=self.reg_entry)
348
+ )
349
+ raise
305
350
 
306
- return res
351
+ eng_logger.info(f'MetaTask {self.task_id} returning result', {'result': res})
307
352
 
308
- try:
309
- res = await run_user_handler(self.process_result, args, kwargs)
310
-
311
- # Send MetaTask result - it could be that there is a nested structure
312
- # of MetaTasks so we need to make sure intermediate results are also sent
313
- if send_stream is not None:
314
- await send_stream.send(
315
- TaskResult(task_id=self.task_id, result=res, cache_key=self.cache_key, reg_entry=self.reg_entry)
316
- )
317
- except BaseException as e:
318
- # Recover from error - update the pending value to prevent subsequent requests getting stuck
319
- if send_stream is not None:
320
- await send_stream.send(
321
- TaskError(task_id=self.task_id, error=e, cache_key=self.cache_key, reg_entry=self.reg_entry)
322
- )
323
- raise
324
-
325
- eng_logger.info(f'MetaTask {self.task_id} returning result', {'result': res})
326
-
327
- return res
353
+ return res
354
+ finally:
355
+ self.cancel_scope = None
328
356
 
329
357
  async def cancel(self):
330
358
  """
@@ -343,8 +371,10 @@ class TaskManager:
343
371
  TaskManager is responsible for running tasks and managing their pending state. It is also responsible for
344
372
  communicating the state of tasks to the client via the WebsocketManager.
345
373
 
346
- When a task is run, a PendingTask it is stored in the tasks dict. It is also stored in the store
347
- with the key of the task's cache_key. This allows the task to be retrieved by the cache_key from the store.
374
+ Every task created gets registered with the TaskManager and is tracked by the TaskManager
375
+ as a PendingTask.
376
+
377
+ This allows the task to be retrieved by the cache_key from the store.
348
378
 
349
379
  When a task is completed, it is removed from the tasks dict and the store entry is updated with the result.
350
380
 
@@ -357,6 +387,14 @@ class TaskManager:
357
387
  self.ws_manager = ws_manager
358
388
  self.store = store
359
389
 
390
+ def register_task(self, task: BaseTask) -> PendingTask:
391
+ """
392
+ Register a task. This will ensure the task it tracked and notifications are routed correctly.
393
+ """
394
+ pending_task = PendingTask(task.task_id, task)
395
+ self.tasks[task.task_id] = pending_task
396
+ return pending_task
397
+
360
398
  @overload
361
399
  async def run_task(self, task: PendingTask, ws_channel: Optional[str] = None) -> Any: ...
362
400
 
@@ -374,7 +412,9 @@ class TaskManager:
374
412
  # append the websocket channel to the task
375
413
  if isinstance(task, PendingTask):
376
414
  if task.task_id in self.tasks:
377
- if ws_channel:
415
+ # Increment subscriber count for this component request
416
+ self.tasks[task.task_id].add_subscriber()
417
+ if ws_channel is not None:
378
418
  self.tasks[task.task_id].notify_channels.append(ws_channel)
379
419
  return self.tasks[task.task_id]
380
420
 
@@ -388,21 +428,65 @@ class TaskManager:
388
428
  else self.get_result(task.task_id)
389
429
  )
390
430
 
391
- # Create and store the pending task
392
- pending_task = PendingTask(task.task_id, task, ws_channel)
393
- if task.cache_key is not None and task.reg_entry is not None:
394
- await self.store.set(task.reg_entry, key=task.cache_key, value=pending_task)
395
-
396
- self.tasks[task.task_id] = pending_task
431
+ # Otherwise, we should already have a pending task for this task
432
+ pending_task = self.tasks[task.task_id]
433
+ if ws_channel is not None:
434
+ pending_task.notify_channels.append(ws_channel)
397
435
 
398
436
  # Run the task in the background
399
- self.task_group.start_soon(self._run_task_and_notify, task, ws_channel)
437
+ self.task_group.start_soon(self._run_task_and_notify, task)
400
438
 
401
439
  return pending_task
402
440
 
441
+ async def _cancel_tasks(self, task_ids: List[str], notify: bool = True):
442
+ """
443
+ Cancel a list of tasks
444
+
445
+ :param task_ids: The list of task IDs to cancel
446
+ :param notify: Whether to send cancellation notifications
447
+ """
448
+ with CancelScope(shield=True):
449
+ # Cancel all tasks in the hierarchy
450
+ for task_id_to_cancel in task_ids:
451
+ if task_id_to_cancel in self.tasks:
452
+ pending_task = self.tasks[task_id_to_cancel]
453
+
454
+ # Notify channels that this specific task was cancelled
455
+ if notify:
456
+ await self._send_notification_for_pending_task(
457
+ pending_task=pending_task,
458
+ messages=[{'status': 'CANCELED', 'task_id': task_id_to_cancel}],
459
+ )
460
+
461
+ if not pending_task.event.is_set():
462
+ # Cancel the actual task
463
+ await pending_task.cancel()
464
+
465
+ # Remove from cache if it has cache settings
466
+ if pending_task.task_def.cache_key is not None and pending_task.task_def.reg_entry is not None:
467
+ await self.store.delete(
468
+ pending_task.task_def.reg_entry, key=pending_task.task_def.cache_key
469
+ )
470
+
471
+ # Remove from running tasks
472
+ self.tasks.pop(task_id_to_cancel, None)
473
+
474
+ dev_logger.info('Task cancelled', {'task_id': task_id_to_cancel})
475
+
476
+ async def _cancel_task_hierarchy(self, task: BaseTask, notify: bool = True):
477
+ """
478
+ Recursively cancel all tasks in a task hierarchy
479
+
480
+ :param task: The root task to cancel (and its children)
481
+ :param notify: Whether to send cancellation notifications
482
+ """
483
+ all_task_ids = self._collect_all_task_ids_in_hierarchy(task)
484
+ await self._cancel_tasks(list(all_task_ids), notify)
485
+
403
486
  async def cancel_task(self, task_id: str, notify: bool = True):
404
487
  """
405
- Cancel a running task by its id
488
+ Cancel a running task by its id. If the task has child tasks (MetaTask),
489
+ all child tasks will also be cancelled.
406
490
 
407
491
  :param task_id: the id of the task
408
492
  :param notify: whether to notify, true by default
@@ -417,22 +501,9 @@ class TaskManager:
417
501
  task.remove_subscriber()
418
502
  return
419
503
 
420
- # Notify any listening channels that the job has been cancelled so that they can handle it correctly
421
- if notify:
422
- for channel in [*task.notify_channels, *task.task_def.notify_channels]:
423
- await self.ws_manager.send_message(channel, {'status': 'CANCELED', 'task_id': task_id})
424
-
425
- # We're only now cancelling the task to make sure the clients are notified about cancelling
426
- # and receive the correct status rather than an error
427
- await task.cancel()
428
-
429
- # Then remove the pending task from cache so next requests would recalculate rather than receive
430
- # a broken pending task
431
- if task.task_def.cache_key is not None and task.task_def.reg_entry is not None:
432
- await self.store.set(task.task_def.reg_entry, key=task.task_def.cache_key, value=None)
504
+ # Cancel the entire task hierarchy (including child tasks)
505
+ await self._cancel_task_hierarchy(task.task_def, notify)
433
506
 
434
- # Remove from running tasks
435
- self.tasks.pop(task_id, None)
436
507
  else:
437
508
  raise TaskManagerError('Could not find a task with the passed id to cancel.')
438
509
 
@@ -453,12 +524,9 @@ class TaskManager:
453
524
 
454
525
  :param task_id: the id of the task to fetch
455
526
  """
456
- result = await self.store.get(TaskResultEntry, key=task_id)
457
-
458
- # Clean up the result afterwards
459
- await self.store.delete(TaskResultEntry, key=task_id)
460
-
461
- return result
527
+ # the result is not deleted, the results are kept in an LRU cache
528
+ # which will clean up older entries
529
+ return await self.store.get(TaskResultEntry, key=task_id)
462
530
 
463
531
  async def set_result(self, task_id: str, value: Any):
464
532
  """
@@ -466,7 +534,82 @@ class TaskManager:
466
534
  """
467
535
  return await self.store.set(TaskResultEntry, key=task_id, value=value)
468
536
 
469
- async def _run_task_and_notify(self, task: BaseTask, ws_channel: Optional[str]):
537
+ def _collect_all_task_ids_in_hierarchy(self, task: BaseTask) -> Set[str]:
538
+ """
539
+ Recursively collect all task IDs in the task hierarchy
540
+
541
+ :param task: The root task to start collecting from
542
+ :return: Set of all task IDs in the hierarchy
543
+ """
544
+ task_ids = {task.task_id}
545
+
546
+ if isinstance(task, MetaTask):
547
+ # Collect from args
548
+ for arg in task.args:
549
+ if isinstance(arg, BaseTask):
550
+ task_ids.update(self._collect_all_task_ids_in_hierarchy(arg))
551
+
552
+ # Collect from kwargs
553
+ for value in task.kwargs.values():
554
+ if isinstance(value, BaseTask):
555
+ task_ids.update(self._collect_all_task_ids_in_hierarchy(value))
556
+
557
+ return task_ids
558
+
559
+ async def _multicast_notification(self, task_id: str, messages: List[dict]):
560
+ """
561
+ Send notifications to all task IDs that are related to a given task
562
+
563
+ :param task: the task the notifications are related to
564
+ :param messages: List of message dictionaries to send to all related tasks
565
+ """
566
+ # prevent cancellation, we need the notifications to be sent
567
+ with CancelScope(shield=True):
568
+ # Find all PendingTasks that have the message_task_id in their hierarchy
569
+ tasks_to_notify = set()
570
+
571
+ for pending_task in self.tasks.values():
572
+ # Check if the message_task_id is in this PendingTask's hierarchy
573
+ task_ids_in_hierarchy = self._collect_all_task_ids_in_hierarchy(pending_task.task_def)
574
+ if task_id in task_ids_in_hierarchy:
575
+ tasks_to_notify.add(pending_task.task_id)
576
+
577
+ # Send notifications for all affected PendingTasks in parallel
578
+ if tasks_to_notify:
579
+ async with create_task_group() as task_tg:
580
+ for pending_task_id in tasks_to_notify:
581
+ if pending_task_id not in self.tasks:
582
+ continue
583
+ pending_task = self.tasks[pending_task_id]
584
+ task_tg.start_soon(self._send_notification_for_pending_task, pending_task, messages)
585
+
586
+ async def _send_notification_for_pending_task(self, pending_task: PendingTask, messages: List[dict]):
587
+ """
588
+ Send notifications for a specific PendingTask
589
+
590
+ :param pending_task: The PendingTask to send notifications for
591
+ :param messages: The messages to send
592
+ """
593
+ # Collect channels for this PendingTask
594
+ channels_to_notify = set(pending_task.notify_channels)
595
+ channels_to_notify.update(pending_task.task_def.notify_channels)
596
+
597
+ if not channels_to_notify:
598
+ return
599
+
600
+ # Send to all channels for this PendingTask in parallel
601
+ async def _send_to_channel(channel: str):
602
+ async with create_task_group() as channel_tg:
603
+ for message in messages:
604
+ # Create message with this PendingTask's task_id (if message has task_id)
605
+ message_for_task = {**message, 'task_id': pending_task.task_id} if 'task_id' in message else message
606
+ channel_tg.start_soon(self.ws_manager.send_message, channel, message_for_task)
607
+
608
+ async with create_task_group() as channel_tg:
609
+ for channel in channels_to_notify:
610
+ channel_tg.start_soon(_send_to_channel, channel)
611
+
612
+ async def _run_task_and_notify(self, task: BaseTask):
470
613
  """
471
614
  Run the task to completion and notify the client of progress and completion
472
615
 
@@ -475,23 +618,13 @@ class TaskManager:
475
618
  """
476
619
  cancel_scope = CancelScope()
477
620
 
478
- self.tasks[task.task_id].cancel_scope = cancel_scope
621
+ pending_task = self.tasks[task.task_id]
622
+
623
+ pending_task.cancel_scope = cancel_scope
479
624
 
480
625
  with cancel_scope:
481
626
  eng_logger.info(f'TaskManager running task {task.task_id}')
482
627
 
483
- async def notify_channels(*messages: dict):
484
- """
485
- Notify the channels of the task's progress
486
- """
487
- channels_to_notify = [*task.notify_channels]
488
- if ws_channel:
489
- channels_to_notify.append(ws_channel)
490
-
491
- for channel in channels_to_notify:
492
- for message in messages:
493
- await self.ws_manager.send_message(channel, message)
494
-
495
628
  # Create a memory object stream to capture messages from the tasks
496
629
  send_stream, receive_stream = create_memory_object_stream[TaskMessage](math.inf)
497
630
 
@@ -499,32 +632,41 @@ class TaskManager:
499
632
  async with receive_stream:
500
633
  async for message in receive_stream:
501
634
  if isinstance(message, TaskProgressUpdate):
502
- # Notify the channels of the task's progress
503
- await notify_channels(
504
- {
505
- 'task_id': task.task_id,
506
- 'status': 'PROGRESS',
507
- 'progress': message.progress,
508
- 'message': message.message,
509
- }
510
- )
635
+ # Send progress notifications to related tasks
636
+ progress_message = {
637
+ 'task_id': message.task_id, # Will be updated per task ID in multicast
638
+ 'status': 'PROGRESS',
639
+ 'progress': message.progress,
640
+ 'message': message.message,
641
+ }
642
+ await self._multicast_notification(message.task_id, [progress_message])
511
643
  if isinstance(task, Task) and task.on_progress:
512
644
  await run_user_handler(task.on_progress, args=(message,))
513
645
  elif isinstance(message, TaskResult):
514
- # Resolve the pending task related to the result
646
+ # Handle dual coordination patterns:
647
+ # 1. Direct-coordinated tasks: resolve via active_tasks registry
515
648
  if message.task_id in self.tasks:
516
- self.tasks[task.task_id].resolve(message.result)
517
- # If the task has a cache key, update the cached value
649
+ self.tasks[message.task_id].resolve(message.result)
650
+
651
+ # 2. Cache-coordinated tasks: resolve via cache store (CacheStore.set handles PendingTask resolution)
518
652
  if (
519
653
  message.cache_key is not None
520
654
  and message.reg_entry is not None
521
655
  and message.reg_entry.cache is not None
522
656
  ):
523
657
  await self.store.set(message.reg_entry, key=message.cache_key, value=message.result)
524
- # Notify the channels of the task's completion
525
- await notify_channels(
526
- {'result': message.result, 'status': 'COMPLETE', 'task_id': message.task_id}
658
+
659
+ # Set final result
660
+ await self.set_result(message.task_id, message.result)
661
+
662
+ # Notify all PendingTasks that depend on this specific task
663
+ await self._multicast_notification(
664
+ task_id=message.task_id,
665
+ messages=[{'result': message.result, 'status': 'COMPLETE', 'task_id': message.task_id}],
527
666
  )
667
+
668
+ # Remove the task from the registered tasks - it finished running
669
+ self.tasks.pop(message.task_id, None)
528
670
  elif isinstance(message, TaskError):
529
671
  # Fail the pending task related to the error
530
672
  if message.task_id in self.tasks:
@@ -537,46 +679,94 @@ class TaskManager:
537
679
  and message.reg_entry is not None
538
680
  and message.reg_entry.cache is not None
539
681
  ):
540
- await self.store.set(message.reg_entry, key=message.cache_key, value=None)
682
+ await self.store.delete(message.reg_entry, key=message.cache_key)
683
+
684
+ # Notify all PendingTasks that depend on this specific task
685
+ error = get_error_for_channel(message.error)
686
+ await self._multicast_notification(
687
+ task_id=message.task_id,
688
+ messages=[
689
+ {'status': 'ERROR', 'task_id': message.task_id, 'error': error['error']},
690
+ error,
691
+ ],
692
+ )
541
693
 
542
- try:
543
- async with create_task_group() as tg:
544
- # Handle incoming messages in parallel
545
- tg.start_soon(handle_messages)
694
+ # Remove the task from the registered tasks - it finished running
695
+ self.tasks.pop(message.task_id, None)
546
696
 
547
- # Handle tasks that return other tasks
548
- async with send_stream:
549
- result = task
550
- while isinstance(result, BaseTask):
551
- result = await task.run(send_stream)
697
+ task_error: Optional[ExceptionGroup] = None
552
698
 
553
- # Set final result
554
- await self.set_result(task.task_id, result)
699
+ # ExceptionGroup handler can't be async so we just mark the task as errored
700
+ # and run the async handler in the finally block
701
+ def handle_exception(err: ExceptionGroup):
702
+ nonlocal task_error
703
+ task_error = err
555
704
 
556
- # Notify any channels that need to be notified about the whole task being completed
557
- await send_stream.send(
558
- TaskResult(
705
+ try:
706
+ with catch({BaseException: handle_exception}): # type: ignore
707
+ async with create_task_group() as tg:
708
+ # Handle incoming messages in parallel
709
+ tg.start_soon(handle_messages)
710
+
711
+ # Handle tasks that return other tasks
712
+ async with send_stream:
713
+ result = task
714
+ while isinstance(result, BaseTask):
715
+ result = await task.run(send_stream)
716
+
717
+ # Notify any channels that need to be notified about the whole task being completed
718
+ await send_stream.send(
719
+ TaskResult(
720
+ task_id=task.task_id,
721
+ result=result,
722
+ cache_key=task.cache_key,
723
+ reg_entry=task.reg_entry,
724
+ )
725
+ )
726
+ eng_logger.info(f'TaskManager finished task {task.task_id}', {'result': result})
727
+ finally:
728
+ with CancelScope(shield=True):
729
+ # pyright: ignore[reportUnreachable]
730
+ if task_error is not None:
731
+ err = task_error
732
+ # Mark pending task as failed
733
+ pending_task.fail(err)
734
+
735
+ await self.set_result(task.task_id, {'error': str(err)})
736
+
737
+ # If the task has a cache key, set cached value to None
738
+ # This makes it so that the next request will recalculate the value rather than keep failing
739
+ if (
740
+ task.cache_key is not None
741
+ and task.reg_entry is not None
742
+ and task.reg_entry.cache is not None
743
+ ):
744
+ await self.store.delete(task.reg_entry, key=task.cache_key)
745
+
746
+ # If this is a cancellation, ensure all tasks in the hierarchy are cancelled
747
+ if exception_group_contains(err_type=get_cancelled_exc_class(), group=err):
748
+ dev_logger.info('Task cancelled', {'task_id': task.task_id})
749
+
750
+ # Cancel any remaining tasks in the hierarchy that might still be running
751
+ await self._cancel_task_hierarchy(task, notify=True)
752
+ else:
753
+ dev_logger.error('Task failed', err, {'task_id': task.task_id})
754
+
755
+ error = get_error_for_channel(err)
756
+ message = {'status': 'ERROR', 'task_id': task.task_id, 'error': error['error']}
757
+ # Notify about this task failing, and a server broadcast error
758
+ await self._send_notification_for_pending_task(
759
+ pending_task=pending_task,
760
+ messages=[message, error],
761
+ )
762
+ # notify related tasks
763
+ await self._multicast_notification(
559
764
  task_id=task.task_id,
560
- result=result,
561
- cache_key=task.cache_key,
562
- reg_entry=task.reg_entry,
765
+ messages=[message],
563
766
  )
564
- )
565
- eng_logger.info(f'TaskManager finished task {task.task_id}', {'result': result})
566
- except (Exception, ExceptionGroup) as err:
567
- err = resolve_exception_group(err)
568
-
569
- # Mark pending task as failed
570
- self.tasks[task.task_id].fail(err)
571
767
 
572
- dev_logger.error('Task failed', err, {'task_id': task.task_id})
573
- await self.set_result(task.task_id, {'error': str(err)})
574
-
575
- # Notify any channels that need to be notified
576
- await notify_channels({'status': 'ERROR', 'task_id': task.task_id}, get_error_for_channel())
577
- finally:
578
- # Remove the task from the running tasks
579
- self.tasks.pop(task.task_id, None)
768
+ # Remove the task from the running tasks
769
+ self.tasks.pop(task.task_id, None)
580
770
 
581
771
  # Make sure streams are closed
582
772
  with move_on_after(3, shield=True):
@@ -584,7 +774,7 @@ class TaskManager:
584
774
  await receive_stream.aclose()
585
775
 
586
776
 
587
- TaskResultEntry = CachedRegistryEntry(uid='task-results', cache=Cache.Policy.KeepAll())
777
+ TaskResultEntry = CachedRegistryEntry(uid='task-results', cache=LruCachePolicy(max_size=256))
588
778
  """
589
779
  Global registry entry for task results.
590
780
  This is global because task ids are unique and accessed one time only so it's effectively a one-time use random key.