dara-core 1.20.0a1__py3-none-any.whl → 1.20.1a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dara/core/__init__.py +0 -3
- dara/core/actions.py +2 -1
- dara/core/auth/basic.py +16 -22
- dara/core/auth/definitions.py +2 -2
- dara/core/auth/routes.py +5 -5
- dara/core/auth/utils.py +5 -5
- dara/core/base_definitions.py +64 -22
- dara/core/cli.py +7 -8
- dara/core/configuration.py +2 -5
- dara/core/css.py +2 -1
- dara/core/data_utils.py +19 -18
- dara/core/defaults.py +7 -6
- dara/core/definitions.py +19 -50
- dara/core/http.py +3 -7
- dara/core/interactivity/__init__.py +0 -6
- dara/core/interactivity/actions.py +50 -52
- dara/core/interactivity/any_data_variable.py +134 -7
- dara/core/interactivity/any_variable.py +8 -5
- dara/core/interactivity/data_variable.py +266 -8
- dara/core/interactivity/derived_data_variable.py +290 -7
- dara/core/interactivity/derived_variable.py +174 -414
- dara/core/interactivity/filtering.py +27 -46
- dara/core/interactivity/loop_variable.py +2 -2
- dara/core/interactivity/non_data_variable.py +68 -5
- dara/core/interactivity/plain_variable.py +15 -89
- dara/core/interactivity/switch_variable.py +19 -19
- dara/core/interactivity/url_variable.py +90 -10
- dara/core/internal/cache_store/base_impl.py +1 -2
- dara/core/internal/cache_store/cache_store.py +25 -22
- dara/core/internal/cache_store/keep_all.py +1 -4
- dara/core/internal/cache_store/lru.py +1 -5
- dara/core/internal/cache_store/ttl.py +1 -4
- dara/core/internal/cgroup.py +1 -1
- dara/core/internal/dependency_resolution.py +66 -60
- dara/core/internal/devtools.py +5 -12
- dara/core/internal/download.py +4 -13
- dara/core/internal/encoder_registry.py +7 -7
- dara/core/internal/execute_action.py +13 -13
- dara/core/internal/hashing.py +3 -1
- dara/core/internal/import_discovery.py +4 -3
- dara/core/internal/normalization.py +18 -9
- dara/core/internal/pandas_utils.py +5 -107
- dara/core/internal/pool/definitions.py +1 -1
- dara/core/internal/pool/task_pool.py +16 -25
- dara/core/internal/pool/utils.py +18 -21
- dara/core/internal/pool/worker.py +2 -3
- dara/core/internal/port_utils.py +1 -1
- dara/core/internal/registries.py +6 -12
- dara/core/internal/registry.py +2 -4
- dara/core/internal/registry_lookup.py +5 -11
- dara/core/internal/routing.py +145 -109
- dara/core/internal/scheduler.py +8 -13
- dara/core/internal/settings.py +2 -2
- dara/core/internal/store.py +29 -2
- dara/core/internal/tasks.py +195 -379
- dara/core/internal/utils.py +13 -36
- dara/core/internal/websocket.py +20 -21
- dara/core/js_tooling/js_utils.py +26 -28
- dara/core/js_tooling/templates/vite.config.template.ts +3 -12
- dara/core/logging.py +12 -13
- dara/core/main.py +11 -14
- dara/core/metrics/cache.py +1 -1
- dara/core/metrics/utils.py +3 -3
- dara/core/persistence.py +5 -27
- dara/core/umd/dara.core.umd.js +55425 -59091
- dara/core/visual/components/__init__.py +2 -2
- dara/core/visual/components/fallback.py +4 -30
- dara/core/visual/components/for_cmp.py +1 -4
- dara/core/visual/css/__init__.py +31 -30
- dara/core/visual/dynamic_component.py +28 -31
- dara/core/visual/progress_updater.py +3 -4
- {dara_core-1.20.0a1.dist-info → dara_core-1.20.1a1.dist-info}/METADATA +11 -12
- dara_core-1.20.1a1.dist-info/RECORD +114 -0
- dara/core/interactivity/client_variable.py +0 -71
- dara/core/interactivity/server_variable.py +0 -325
- dara/core/interactivity/state_variable.py +0 -69
- dara/core/interactivity/tabular_variable.py +0 -94
- dara/core/internal/multi_resource_lock.py +0 -70
- dara_core-1.20.0a1.dist-info/RECORD +0 -119
- {dara_core-1.20.0a1.dist-info → dara_core-1.20.1a1.dist-info}/LICENSE +0 -0
- {dara_core-1.20.0a1.dist-info → dara_core-1.20.1a1.dist-info}/WHEEL +0 -0
- {dara_core-1.20.0a1.dist-info → dara_core-1.20.1a1.dist-info}/entry_points.txt +0 -0
dara/core/internal/tasks.py
CHANGED
|
@@ -15,30 +15,26 @@ See the License for the specific language governing permissions and
|
|
|
15
15
|
limitations under the License.
|
|
16
16
|
"""
|
|
17
17
|
|
|
18
|
-
import contextlib
|
|
19
18
|
import inspect
|
|
20
19
|
import math
|
|
21
|
-
from
|
|
22
|
-
from typing import Any, Callable, Dict, List, Optional, Set, Union, overload
|
|
20
|
+
from typing import Any, Awaitable, Callable, Dict, List, Optional, Union, overload
|
|
23
21
|
|
|
24
22
|
from anyio import (
|
|
25
|
-
BrokenResourceError,
|
|
26
23
|
CancelScope,
|
|
27
24
|
ClosedResourceError,
|
|
28
25
|
create_memory_object_stream,
|
|
29
26
|
create_task_group,
|
|
30
|
-
get_cancelled_exc_class,
|
|
31
27
|
move_on_after,
|
|
32
28
|
)
|
|
33
29
|
from anyio.abc import TaskGroup
|
|
34
30
|
from anyio.streams.memory import MemoryObjectSendStream
|
|
35
|
-
from exceptiongroup import ExceptionGroup
|
|
31
|
+
from exceptiongroup import ExceptionGroup
|
|
36
32
|
from pydantic import ConfigDict
|
|
37
33
|
|
|
38
34
|
from dara.core.base_definitions import (
|
|
39
35
|
BaseTask,
|
|
36
|
+
Cache,
|
|
40
37
|
CachedRegistryEntry,
|
|
41
|
-
LruCachePolicy,
|
|
42
38
|
PendingTask,
|
|
43
39
|
TaskError,
|
|
44
40
|
TaskMessage,
|
|
@@ -49,7 +45,7 @@ from dara.core.internal.cache_store import CacheStore
|
|
|
49
45
|
from dara.core.internal.devtools import get_error_for_channel
|
|
50
46
|
from dara.core.internal.pandas_utils import remove_index
|
|
51
47
|
from dara.core.internal.pool import TaskPool
|
|
52
|
-
from dara.core.internal.utils import
|
|
48
|
+
from dara.core.internal.utils import resolve_exception_group, run_user_handler
|
|
53
49
|
from dara.core.internal.websocket import WebsocketManager
|
|
54
50
|
from dara.core.logging import dev_logger, eng_logger
|
|
55
51
|
from dara.core.metrics import RUNTIME_METRICS_TRACKER
|
|
@@ -141,12 +137,14 @@ class Task(BaseTask):
|
|
|
141
137
|
|
|
142
138
|
async def on_progress(progress: float, msg: str):
|
|
143
139
|
if send_stream is not None:
|
|
144
|
-
|
|
140
|
+
try:
|
|
145
141
|
await send_stream.send(TaskProgressUpdate(task_id=self.task_id, progress=progress, message=msg))
|
|
142
|
+
except ClosedResourceError:
|
|
143
|
+
pass
|
|
146
144
|
|
|
147
145
|
async def on_result(result: Any):
|
|
148
146
|
if send_stream is not None:
|
|
149
|
-
|
|
147
|
+
try:
|
|
150
148
|
await send_stream.send(
|
|
151
149
|
TaskResult(
|
|
152
150
|
task_id=self.task_id,
|
|
@@ -155,15 +153,19 @@ class Task(BaseTask):
|
|
|
155
153
|
reg_entry=self.reg_entry,
|
|
156
154
|
)
|
|
157
155
|
)
|
|
156
|
+
except ClosedResourceError:
|
|
157
|
+
pass
|
|
158
158
|
|
|
159
159
|
async def on_error(exc: BaseException):
|
|
160
160
|
if send_stream is not None:
|
|
161
|
-
|
|
161
|
+
try:
|
|
162
162
|
await send_stream.send(
|
|
163
163
|
TaskError(
|
|
164
164
|
task_id=self.task_id, error=exc, cache_key=self.cache_key, reg_entry=self.reg_entry
|
|
165
165
|
)
|
|
166
166
|
)
|
|
167
|
+
except ClosedResourceError:
|
|
168
|
+
pass
|
|
167
169
|
|
|
168
170
|
with pool.on_progress(self.task_id, on_progress):
|
|
169
171
|
pool_task_def = pool.submit(self.task_id, self._func_name, args=tuple(self._args), kwargs=self._kwargs)
|
|
@@ -216,7 +218,7 @@ class MetaTask(BaseTask):
|
|
|
216
218
|
:param notify_channels: If this task is run in a TaskManager instance these channels will also be notified on
|
|
217
219
|
completion
|
|
218
220
|
:param process_as_task: Whether to run the process_result function as a task or not, defaults to False
|
|
219
|
-
:param cache_key: Optional cache key if there is a
|
|
221
|
+
:param cache_key: Optional cache key if there is a PendingTask in the store associated with this task
|
|
220
222
|
:param task_id: Optional task_id to set for the task - otherwise the task generates its id automatically
|
|
221
223
|
"""
|
|
222
224
|
self.args = args if args is not None else []
|
|
@@ -237,122 +239,96 @@ class MetaTask(BaseTask):
|
|
|
237
239
|
|
|
238
240
|
:param send_stream: The stream to send messages to the task manager on
|
|
239
241
|
"""
|
|
240
|
-
|
|
242
|
+
tasks: List[BaseTask] = []
|
|
241
243
|
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
)
|
|
270
|
-
)
|
|
271
|
-
raise err
|
|
272
|
-
|
|
273
|
-
if len(tasks) > 0:
|
|
274
|
-
with catch(
|
|
275
|
-
{
|
|
276
|
-
BaseException: handle_exception, # type: ignore
|
|
277
|
-
get_cancelled_exc_class(): handle_exception,
|
|
278
|
-
}
|
|
279
|
-
):
|
|
280
|
-
async with create_task_group() as tg:
|
|
281
|
-
for task in tasks:
|
|
282
|
-
tg.start_soon(_run_and_capture_result, task)
|
|
283
|
-
|
|
284
|
-
# In testing somehow sometimes cancellations aren't bubbled up, this ensures that's the case
|
|
285
|
-
if tg.cancel_scope.cancel_called:
|
|
286
|
-
raise get_cancelled_exc_class()('MetaTask caught cancellation')
|
|
287
|
-
|
|
288
|
-
eng_logger.debug(f'MetaTask {self.task_id}', 'completed sub-tasks', results)
|
|
289
|
-
|
|
290
|
-
# Order the results in the same order as the tasks list
|
|
291
|
-
result_values = [results[task.task_id] for task in tasks]
|
|
292
|
-
|
|
293
|
-
args = []
|
|
294
|
-
kwargs = {}
|
|
295
|
-
|
|
296
|
-
# Rebuild the args and kwargs with the results of the underlying tasks
|
|
297
|
-
# Here the task results could be DataFrames so make sure we clean the internal __index__ col from them
|
|
298
|
-
# before passing into the task function
|
|
299
|
-
for arg in self.args:
|
|
300
|
-
if isinstance(arg, BaseTask):
|
|
301
|
-
args.append(remove_index(result_values.pop(0)))
|
|
302
|
-
else:
|
|
303
|
-
args.append(arg)
|
|
304
|
-
|
|
305
|
-
for k, val in self.kwargs.items():
|
|
306
|
-
if isinstance(val, BaseTask):
|
|
307
|
-
kwargs[k] = remove_index(result_values.pop(0))
|
|
308
|
-
else:
|
|
309
|
-
kwargs[k] = val
|
|
310
|
-
|
|
311
|
-
eng_logger.debug(f'MetaTask {self.task_id}', 'processing result', {'args': args, 'kwargs': kwargs})
|
|
312
|
-
|
|
313
|
-
# Run the process result function with the completed set of args and kwargs
|
|
314
|
-
if self.process_as_task:
|
|
315
|
-
eng_logger.debug(f'MetaTask {self.task_id}', 'processing result as Task')
|
|
316
|
-
|
|
317
|
-
task = Task(
|
|
318
|
-
self.process_result,
|
|
319
|
-
args,
|
|
320
|
-
kwargs,
|
|
321
|
-
# Pass through cache_key so the processing task correctly updates the cache store entry
|
|
322
|
-
reg_entry=self.reg_entry,
|
|
323
|
-
cache_key=self.cache_key,
|
|
324
|
-
task_id=self.task_id,
|
|
244
|
+
# Collect up the tasks that need to be run and kick them off without awaiting them.
|
|
245
|
+
tasks.extend(x for x in self.args if isinstance(x, BaseTask))
|
|
246
|
+
tasks.extend(x for x in self.kwargs.values() if isinstance(x, BaseTask))
|
|
247
|
+
|
|
248
|
+
eng_logger.info(f'MetaTask {self.task_id} running sub-tasks', {'task_ids': [x.task_id for x in tasks]})
|
|
249
|
+
|
|
250
|
+
# Wait for all tasks to complete
|
|
251
|
+
results: Dict[str, Any] = {}
|
|
252
|
+
|
|
253
|
+
async def _run_and_capture_result(task: BaseTask):
|
|
254
|
+
"""
|
|
255
|
+
Run a task and capture the result
|
|
256
|
+
"""
|
|
257
|
+
nonlocal results
|
|
258
|
+
result = await task.run(send_stream)
|
|
259
|
+
results[task.task_id] = result
|
|
260
|
+
|
|
261
|
+
if len(tasks) > 0:
|
|
262
|
+
try:
|
|
263
|
+
async with create_task_group() as tg:
|
|
264
|
+
self.cancel_scope = tg.cancel_scope
|
|
265
|
+
for task in tasks:
|
|
266
|
+
tg.start_soon(_run_and_capture_result, task)
|
|
267
|
+
except BaseException as e:
|
|
268
|
+
if send_stream is not None:
|
|
269
|
+
await send_stream.send(
|
|
270
|
+
TaskError(task_id=self.task_id, error=e, cache_key=self.cache_key, reg_entry=self.reg_entry)
|
|
325
271
|
)
|
|
326
|
-
|
|
272
|
+
raise
|
|
273
|
+
finally:
|
|
274
|
+
self.cancel_scope = None
|
|
327
275
|
|
|
328
|
-
|
|
276
|
+
eng_logger.debug(f'MetaTask {self.task_id}', 'completed sub-tasks', results)
|
|
329
277
|
|
|
330
|
-
|
|
278
|
+
# Order the results in the same order as the tasks list
|
|
279
|
+
result_values = [results[task.task_id] for task in tasks]
|
|
331
280
|
|
|
332
|
-
|
|
333
|
-
|
|
281
|
+
args = []
|
|
282
|
+
kwargs = {}
|
|
334
283
|
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
284
|
+
# Rebuild the args and kwargs with the results of the underlying tasks
|
|
285
|
+
# Here the task results could be DataFrames so make sure we clean the internal __index__ col from them
|
|
286
|
+
# before passing into the task function
|
|
287
|
+
for arg in self.args:
|
|
288
|
+
if isinstance(arg, BaseTask):
|
|
289
|
+
args.append(remove_index(result_values.pop(0)))
|
|
290
|
+
else:
|
|
291
|
+
args.append(arg)
|
|
292
|
+
|
|
293
|
+
for k, val in self.kwargs.items():
|
|
294
|
+
if isinstance(val, BaseTask):
|
|
295
|
+
kwargs[k] = remove_index(result_values.pop(0))
|
|
296
|
+
else:
|
|
297
|
+
kwargs[k] = val
|
|
298
|
+
|
|
299
|
+
eng_logger.debug(f'MetaTask {self.task_id}', 'processing result', {'args': args, 'kwargs': kwargs})
|
|
350
300
|
|
|
351
|
-
|
|
301
|
+
# Run the process result function with the completed set of args and kwargs
|
|
302
|
+
if self.process_as_task:
|
|
303
|
+
eng_logger.debug(f'MetaTask {self.task_id}', 'processing result as Task')
|
|
304
|
+
# Pass through cache_key so the processing task correctly updates the cache store entry
|
|
305
|
+
task = Task(self.process_result, args, kwargs, cache_key=self.cache_key)
|
|
306
|
+
res = await task.run(send_stream)
|
|
352
307
|
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
308
|
+
eng_logger.info(f'MetaTask {self.task_id} returning result', {'result': res})
|
|
309
|
+
|
|
310
|
+
return res
|
|
311
|
+
|
|
312
|
+
try:
|
|
313
|
+
res = await run_user_handler(self.process_result, args, kwargs)
|
|
314
|
+
|
|
315
|
+
# Send MetaTask result - it could be that there is a nested structure
|
|
316
|
+
# of MetaTasks so we need to make sure intermediate results are also sent
|
|
317
|
+
if send_stream is not None:
|
|
318
|
+
await send_stream.send(
|
|
319
|
+
TaskResult(task_id=self.task_id, result=res, cache_key=self.cache_key, reg_entry=self.reg_entry)
|
|
320
|
+
)
|
|
321
|
+
except BaseException as e:
|
|
322
|
+
# Recover from error - update the pending value to prevent subsequent requests getting stuck
|
|
323
|
+
if send_stream is not None:
|
|
324
|
+
await send_stream.send(
|
|
325
|
+
TaskError(task_id=self.task_id, error=e, cache_key=self.cache_key, reg_entry=self.reg_entry)
|
|
326
|
+
)
|
|
327
|
+
raise
|
|
328
|
+
|
|
329
|
+
eng_logger.info(f'MetaTask {self.task_id} returning result', {'result': res})
|
|
330
|
+
|
|
331
|
+
return res
|
|
356
332
|
|
|
357
333
|
async def cancel(self):
|
|
358
334
|
"""
|
|
@@ -371,10 +347,8 @@ class TaskManager:
|
|
|
371
347
|
TaskManager is responsible for running tasks and managing their pending state. It is also responsible for
|
|
372
348
|
communicating the state of tasks to the client via the WebsocketManager.
|
|
373
349
|
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
This allows the task to be retrieved by the cache_key from the store.
|
|
350
|
+
When a task is run, a PendingTask it is stored in the tasks dict. It is also stored in the store
|
|
351
|
+
with the key of the task's cache_key. This allows the task to be retrieved by the cache_key from the store.
|
|
378
352
|
|
|
379
353
|
When a task is completed, it is removed from the tasks dict and the store entry is updated with the result.
|
|
380
354
|
|
|
@@ -387,19 +361,13 @@ class TaskManager:
|
|
|
387
361
|
self.ws_manager = ws_manager
|
|
388
362
|
self.store = store
|
|
389
363
|
|
|
390
|
-
def register_task(self, task: BaseTask) -> PendingTask:
|
|
391
|
-
"""
|
|
392
|
-
Register a task. This will ensure the task it tracked and notifications are routed correctly.
|
|
393
|
-
"""
|
|
394
|
-
pending_task = PendingTask(task.task_id, task)
|
|
395
|
-
self.tasks[task.task_id] = pending_task
|
|
396
|
-
return pending_task
|
|
397
|
-
|
|
398
364
|
@overload
|
|
399
|
-
async def run_task(self, task: PendingTask, ws_channel: Optional[str] = None) -> Any:
|
|
365
|
+
async def run_task(self, task: PendingTask, ws_channel: Optional[str] = None) -> Any:
|
|
366
|
+
...
|
|
400
367
|
|
|
401
368
|
@overload
|
|
402
|
-
async def run_task(self, task: BaseTask, ws_channel: Optional[str] = None) -> PendingTask:
|
|
369
|
+
async def run_task(self, task: BaseTask, ws_channel: Optional[str] = None) -> PendingTask:
|
|
370
|
+
...
|
|
403
371
|
|
|
404
372
|
async def run_task(self, task: BaseTask, ws_channel: Optional[str] = None):
|
|
405
373
|
"""
|
|
@@ -412,9 +380,7 @@ class TaskManager:
|
|
|
412
380
|
# append the websocket channel to the task
|
|
413
381
|
if isinstance(task, PendingTask):
|
|
414
382
|
if task.task_id in self.tasks:
|
|
415
|
-
|
|
416
|
-
self.tasks[task.task_id].add_subscriber()
|
|
417
|
-
if ws_channel is not None:
|
|
383
|
+
if ws_channel:
|
|
418
384
|
self.tasks[task.task_id].notify_channels.append(ws_channel)
|
|
419
385
|
return self.tasks[task.task_id]
|
|
420
386
|
|
|
@@ -428,65 +394,21 @@ class TaskManager:
|
|
|
428
394
|
else self.get_result(task.task_id)
|
|
429
395
|
)
|
|
430
396
|
|
|
431
|
-
#
|
|
432
|
-
pending_task =
|
|
433
|
-
if
|
|
434
|
-
|
|
397
|
+
# Create and store the pending task
|
|
398
|
+
pending_task = PendingTask(task.task_id, task, ws_channel)
|
|
399
|
+
if task.cache_key is not None and task.reg_entry is not None:
|
|
400
|
+
await self.store.set(task.reg_entry, key=task.cache_key, value=pending_task)
|
|
401
|
+
|
|
402
|
+
self.tasks[task.task_id] = pending_task
|
|
435
403
|
|
|
436
404
|
# Run the task in the background
|
|
437
|
-
self.task_group.start_soon(self._run_task_and_notify, task)
|
|
405
|
+
self.task_group.start_soon(self._run_task_and_notify, task, ws_channel)
|
|
438
406
|
|
|
439
407
|
return pending_task
|
|
440
408
|
|
|
441
|
-
async def _cancel_tasks(self, task_ids: List[str], notify: bool = True):
|
|
442
|
-
"""
|
|
443
|
-
Cancel a list of tasks
|
|
444
|
-
|
|
445
|
-
:param task_ids: The list of task IDs to cancel
|
|
446
|
-
:param notify: Whether to send cancellation notifications
|
|
447
|
-
"""
|
|
448
|
-
with CancelScope(shield=True):
|
|
449
|
-
# Cancel all tasks in the hierarchy
|
|
450
|
-
for task_id_to_cancel in task_ids:
|
|
451
|
-
if task_id_to_cancel in self.tasks:
|
|
452
|
-
pending_task = self.tasks[task_id_to_cancel]
|
|
453
|
-
|
|
454
|
-
# Notify channels that this specific task was cancelled
|
|
455
|
-
if notify:
|
|
456
|
-
await self._send_notification_for_pending_task(
|
|
457
|
-
pending_task=pending_task,
|
|
458
|
-
messages=[{'status': 'CANCELED', 'task_id': task_id_to_cancel}],
|
|
459
|
-
)
|
|
460
|
-
|
|
461
|
-
if not pending_task.event.is_set():
|
|
462
|
-
# Cancel the actual task
|
|
463
|
-
await pending_task.cancel()
|
|
464
|
-
|
|
465
|
-
# Remove from cache if it has cache settings
|
|
466
|
-
if pending_task.task_def.cache_key is not None and pending_task.task_def.reg_entry is not None:
|
|
467
|
-
await self.store.delete(
|
|
468
|
-
pending_task.task_def.reg_entry, key=pending_task.task_def.cache_key
|
|
469
|
-
)
|
|
470
|
-
|
|
471
|
-
# Remove from running tasks
|
|
472
|
-
self.tasks.pop(task_id_to_cancel, None)
|
|
473
|
-
|
|
474
|
-
dev_logger.info('Task cancelled', {'task_id': task_id_to_cancel})
|
|
475
|
-
|
|
476
|
-
async def _cancel_task_hierarchy(self, task: BaseTask, notify: bool = True):
|
|
477
|
-
"""
|
|
478
|
-
Recursively cancel all tasks in a task hierarchy
|
|
479
|
-
|
|
480
|
-
:param task: The root task to cancel (and its children)
|
|
481
|
-
:param notify: Whether to send cancellation notifications
|
|
482
|
-
"""
|
|
483
|
-
all_task_ids = self._collect_all_task_ids_in_hierarchy(task)
|
|
484
|
-
await self._cancel_tasks(list(all_task_ids), notify)
|
|
485
|
-
|
|
486
409
|
async def cancel_task(self, task_id: str, notify: bool = True):
|
|
487
410
|
"""
|
|
488
|
-
Cancel a running task by its id
|
|
489
|
-
all child tasks will also be cancelled.
|
|
411
|
+
Cancel a running task by its id
|
|
490
412
|
|
|
491
413
|
:param task_id: the id of the task
|
|
492
414
|
:param notify: whether to notify, true by default
|
|
@@ -501,9 +423,22 @@ class TaskManager:
|
|
|
501
423
|
task.remove_subscriber()
|
|
502
424
|
return
|
|
503
425
|
|
|
504
|
-
#
|
|
505
|
-
|
|
426
|
+
# Notify any listening channels that the job has been cancelled so that they can handle it correctly
|
|
427
|
+
if notify:
|
|
428
|
+
for channel in [*task.notify_channels, *task.task_def.notify_channels]:
|
|
429
|
+
await self.ws_manager.send_message(channel, {'status': 'CANCELED', 'task_id': task_id})
|
|
430
|
+
|
|
431
|
+
# We're only now cancelling the task to make sure the clients are notified about cancelling
|
|
432
|
+
# and receive the correct status rather than an error
|
|
433
|
+
await task.cancel()
|
|
434
|
+
|
|
435
|
+
# Then remove the pending task from cache so next requests would recalculate rather than receive
|
|
436
|
+
# a broken pending task
|
|
437
|
+
if task.task_def.cache_key is not None and task.task_def.reg_entry is not None:
|
|
438
|
+
await self.store.set(task.task_def.reg_entry, key=task.task_def.cache_key, value=None)
|
|
506
439
|
|
|
440
|
+
# Remove from running tasks
|
|
441
|
+
self.tasks.pop(task_id, None)
|
|
507
442
|
else:
|
|
508
443
|
raise TaskManagerError('Could not find a task with the passed id to cancel.')
|
|
509
444
|
|
|
@@ -524,9 +459,12 @@ class TaskManager:
|
|
|
524
459
|
|
|
525
460
|
:param task_id: the id of the task to fetch
|
|
526
461
|
"""
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
462
|
+
result = await self.store.get(TaskResultEntry, key=task_id)
|
|
463
|
+
|
|
464
|
+
# Clean up the result afterwards
|
|
465
|
+
await self.store.delete(TaskResultEntry, key=task_id)
|
|
466
|
+
|
|
467
|
+
return result
|
|
530
468
|
|
|
531
469
|
async def set_result(self, task_id: str, value: Any):
|
|
532
470
|
"""
|
|
@@ -534,82 +472,7 @@ class TaskManager:
|
|
|
534
472
|
"""
|
|
535
473
|
return await self.store.set(TaskResultEntry, key=task_id, value=value)
|
|
536
474
|
|
|
537
|
-
def
|
|
538
|
-
"""
|
|
539
|
-
Recursively collect all task IDs in the task hierarchy
|
|
540
|
-
|
|
541
|
-
:param task: The root task to start collecting from
|
|
542
|
-
:return: Set of all task IDs in the hierarchy
|
|
543
|
-
"""
|
|
544
|
-
task_ids = {task.task_id}
|
|
545
|
-
|
|
546
|
-
if isinstance(task, MetaTask):
|
|
547
|
-
# Collect from args
|
|
548
|
-
for arg in task.args:
|
|
549
|
-
if isinstance(arg, BaseTask):
|
|
550
|
-
task_ids.update(self._collect_all_task_ids_in_hierarchy(arg))
|
|
551
|
-
|
|
552
|
-
# Collect from kwargs
|
|
553
|
-
for value in task.kwargs.values():
|
|
554
|
-
if isinstance(value, BaseTask):
|
|
555
|
-
task_ids.update(self._collect_all_task_ids_in_hierarchy(value))
|
|
556
|
-
|
|
557
|
-
return task_ids
|
|
558
|
-
|
|
559
|
-
async def _multicast_notification(self, task_id: str, messages: List[dict]):
|
|
560
|
-
"""
|
|
561
|
-
Send notifications to all task IDs that are related to a given task
|
|
562
|
-
|
|
563
|
-
:param task: the task the notifications are related to
|
|
564
|
-
:param messages: List of message dictionaries to send to all related tasks
|
|
565
|
-
"""
|
|
566
|
-
# prevent cancellation, we need the notifications to be sent
|
|
567
|
-
with CancelScope(shield=True):
|
|
568
|
-
# Find all PendingTasks that have the message_task_id in their hierarchy
|
|
569
|
-
tasks_to_notify = set()
|
|
570
|
-
|
|
571
|
-
for pending_task in self.tasks.values():
|
|
572
|
-
# Check if the message_task_id is in this PendingTask's hierarchy
|
|
573
|
-
task_ids_in_hierarchy = self._collect_all_task_ids_in_hierarchy(pending_task.task_def)
|
|
574
|
-
if task_id in task_ids_in_hierarchy:
|
|
575
|
-
tasks_to_notify.add(pending_task.task_id)
|
|
576
|
-
|
|
577
|
-
# Send notifications for all affected PendingTasks in parallel
|
|
578
|
-
if tasks_to_notify:
|
|
579
|
-
async with create_task_group() as task_tg:
|
|
580
|
-
for pending_task_id in tasks_to_notify:
|
|
581
|
-
if pending_task_id not in self.tasks:
|
|
582
|
-
continue
|
|
583
|
-
pending_task = self.tasks[pending_task_id]
|
|
584
|
-
task_tg.start_soon(self._send_notification_for_pending_task, pending_task, messages)
|
|
585
|
-
|
|
586
|
-
async def _send_notification_for_pending_task(self, pending_task: PendingTask, messages: List[dict]):
|
|
587
|
-
"""
|
|
588
|
-
Send notifications for a specific PendingTask
|
|
589
|
-
|
|
590
|
-
:param pending_task: The PendingTask to send notifications for
|
|
591
|
-
:param messages: The messages to send
|
|
592
|
-
"""
|
|
593
|
-
# Collect channels for this PendingTask
|
|
594
|
-
channels_to_notify = set(pending_task.notify_channels)
|
|
595
|
-
channels_to_notify.update(pending_task.task_def.notify_channels)
|
|
596
|
-
|
|
597
|
-
if not channels_to_notify:
|
|
598
|
-
return
|
|
599
|
-
|
|
600
|
-
# Send to all channels for this PendingTask in parallel
|
|
601
|
-
async def _send_to_channel(channel: str):
|
|
602
|
-
async with create_task_group() as channel_tg:
|
|
603
|
-
for message in messages:
|
|
604
|
-
# Create message with this PendingTask's task_id (if message has task_id)
|
|
605
|
-
message_for_task = {**message, 'task_id': pending_task.task_id} if 'task_id' in message else message
|
|
606
|
-
channel_tg.start_soon(self.ws_manager.send_message, channel, message_for_task)
|
|
607
|
-
|
|
608
|
-
async with create_task_group() as channel_tg:
|
|
609
|
-
for channel in channels_to_notify:
|
|
610
|
-
channel_tg.start_soon(_send_to_channel, channel)
|
|
611
|
-
|
|
612
|
-
async def _run_task_and_notify(self, task: BaseTask):
|
|
475
|
+
async def _run_task_and_notify(self, task: BaseTask, ws_channel: Optional[str]):
|
|
613
476
|
"""
|
|
614
477
|
Run the task to completion and notify the client of progress and completion
|
|
615
478
|
|
|
@@ -618,13 +481,23 @@ class TaskManager:
|
|
|
618
481
|
"""
|
|
619
482
|
cancel_scope = CancelScope()
|
|
620
483
|
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
pending_task.cancel_scope = cancel_scope
|
|
484
|
+
self.tasks[task.task_id].cancel_scope = cancel_scope
|
|
624
485
|
|
|
625
486
|
with cancel_scope:
|
|
626
487
|
eng_logger.info(f'TaskManager running task {task.task_id}')
|
|
627
488
|
|
|
489
|
+
async def notify_channels(*messages: dict):
|
|
490
|
+
"""
|
|
491
|
+
Notify the channels of the task's progress
|
|
492
|
+
"""
|
|
493
|
+
channels_to_notify = [*task.notify_channels]
|
|
494
|
+
if ws_channel:
|
|
495
|
+
channels_to_notify.append(ws_channel)
|
|
496
|
+
|
|
497
|
+
for channel in channels_to_notify:
|
|
498
|
+
for message in messages:
|
|
499
|
+
await self.ws_manager.send_message(channel, message)
|
|
500
|
+
|
|
628
501
|
# Create a memory object stream to capture messages from the tasks
|
|
629
502
|
send_stream, receive_stream = create_memory_object_stream[TaskMessage](math.inf)
|
|
630
503
|
|
|
@@ -632,41 +505,32 @@ class TaskManager:
|
|
|
632
505
|
async with receive_stream:
|
|
633
506
|
async for message in receive_stream:
|
|
634
507
|
if isinstance(message, TaskProgressUpdate):
|
|
635
|
-
#
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
508
|
+
# Notify the channels of the task's progress
|
|
509
|
+
await notify_channels(
|
|
510
|
+
{
|
|
511
|
+
'task_id': task.task_id,
|
|
512
|
+
'status': 'PROGRESS',
|
|
513
|
+
'progress': message.progress,
|
|
514
|
+
'message': message.message,
|
|
515
|
+
}
|
|
516
|
+
)
|
|
643
517
|
if isinstance(task, Task) and task.on_progress:
|
|
644
518
|
await run_user_handler(task.on_progress, args=(message,))
|
|
645
519
|
elif isinstance(message, TaskResult):
|
|
646
|
-
#
|
|
647
|
-
# 1. Direct-coordinated tasks: resolve via active_tasks registry
|
|
520
|
+
# Resolve the pending task related to the result
|
|
648
521
|
if message.task_id in self.tasks:
|
|
649
|
-
self.tasks[
|
|
650
|
-
|
|
651
|
-
# 2. Cache-coordinated tasks: resolve via cache store (CacheStore.set handles PendingTask resolution)
|
|
522
|
+
self.tasks[task.task_id].resolve(message.result)
|
|
523
|
+
# If the task has a cache key, update the cached value
|
|
652
524
|
if (
|
|
653
525
|
message.cache_key is not None
|
|
654
526
|
and message.reg_entry is not None
|
|
655
527
|
and message.reg_entry.cache is not None
|
|
656
528
|
):
|
|
657
529
|
await self.store.set(message.reg_entry, key=message.cache_key, value=message.result)
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
# Notify all PendingTasks that depend on this specific task
|
|
663
|
-
await self._multicast_notification(
|
|
664
|
-
task_id=message.task_id,
|
|
665
|
-
messages=[{'result': message.result, 'status': 'COMPLETE', 'task_id': message.task_id}],
|
|
530
|
+
# Notify the channels of the task's completion
|
|
531
|
+
await notify_channels(
|
|
532
|
+
{'result': message.result, 'status': 'COMPLETE', 'task_id': message.task_id}
|
|
666
533
|
)
|
|
667
|
-
|
|
668
|
-
# Remove the task from the registered tasks - it finished running
|
|
669
|
-
self.tasks.pop(message.task_id, None)
|
|
670
534
|
elif isinstance(message, TaskError):
|
|
671
535
|
# Fail the pending task related to the error
|
|
672
536
|
if message.task_id in self.tasks:
|
|
@@ -679,94 +543,46 @@ class TaskManager:
|
|
|
679
543
|
and message.reg_entry is not None
|
|
680
544
|
and message.reg_entry.cache is not None
|
|
681
545
|
):
|
|
682
|
-
await self.store.
|
|
683
|
-
|
|
684
|
-
# Notify all PendingTasks that depend on this specific task
|
|
685
|
-
error = get_error_for_channel(message.error)
|
|
686
|
-
await self._multicast_notification(
|
|
687
|
-
task_id=message.task_id,
|
|
688
|
-
messages=[
|
|
689
|
-
{'status': 'ERROR', 'task_id': message.task_id, 'error': error['error']},
|
|
690
|
-
error,
|
|
691
|
-
],
|
|
692
|
-
)
|
|
546
|
+
await self.store.set(message.reg_entry, key=message.cache_key, value=None)
|
|
693
547
|
|
|
694
|
-
|
|
695
|
-
|
|
548
|
+
try:
|
|
549
|
+
async with create_task_group() as tg:
|
|
550
|
+
# Handle incoming messages in parallel
|
|
551
|
+
tg.start_soon(handle_messages)
|
|
696
552
|
|
|
697
|
-
|
|
553
|
+
# Handle tasks that return other tasks
|
|
554
|
+
async with send_stream:
|
|
555
|
+
result = task
|
|
556
|
+
while isinstance(result, BaseTask):
|
|
557
|
+
result = await task.run(send_stream)
|
|
698
558
|
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
def handle_exception(err: ExceptionGroup):
|
|
702
|
-
nonlocal task_error
|
|
703
|
-
task_error = err
|
|
559
|
+
# Set final result
|
|
560
|
+
await self.set_result(task.task_id, result)
|
|
704
561
|
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
# Handle incoming messages in parallel
|
|
709
|
-
tg.start_soon(handle_messages)
|
|
710
|
-
|
|
711
|
-
# Handle tasks that return other tasks
|
|
712
|
-
async with send_stream:
|
|
713
|
-
result = task
|
|
714
|
-
while isinstance(result, BaseTask):
|
|
715
|
-
result = await task.run(send_stream)
|
|
716
|
-
|
|
717
|
-
# Notify any channels that need to be notified about the whole task being completed
|
|
718
|
-
await send_stream.send(
|
|
719
|
-
TaskResult(
|
|
720
|
-
task_id=task.task_id,
|
|
721
|
-
result=result,
|
|
722
|
-
cache_key=task.cache_key,
|
|
723
|
-
reg_entry=task.reg_entry,
|
|
724
|
-
)
|
|
725
|
-
)
|
|
726
|
-
eng_logger.info(f'TaskManager finished task {task.task_id}', {'result': result})
|
|
727
|
-
finally:
|
|
728
|
-
with CancelScope(shield=True):
|
|
729
|
-
# pyright: ignore[reportUnreachable]
|
|
730
|
-
if task_error is not None:
|
|
731
|
-
err = task_error
|
|
732
|
-
# Mark pending task as failed
|
|
733
|
-
pending_task.fail(err)
|
|
734
|
-
|
|
735
|
-
await self.set_result(task.task_id, {'error': str(err)})
|
|
736
|
-
|
|
737
|
-
# If the task has a cache key, set cached value to None
|
|
738
|
-
# This makes it so that the next request will recalculate the value rather than keep failing
|
|
739
|
-
if (
|
|
740
|
-
task.cache_key is not None
|
|
741
|
-
and task.reg_entry is not None
|
|
742
|
-
and task.reg_entry.cache is not None
|
|
743
|
-
):
|
|
744
|
-
await self.store.delete(task.reg_entry, key=task.cache_key)
|
|
745
|
-
|
|
746
|
-
# If this is a cancellation, ensure all tasks in the hierarchy are cancelled
|
|
747
|
-
if exception_group_contains(err_type=get_cancelled_exc_class(), group=err):
|
|
748
|
-
dev_logger.info('Task cancelled', {'task_id': task.task_id})
|
|
749
|
-
|
|
750
|
-
# Cancel any remaining tasks in the hierarchy that might still be running
|
|
751
|
-
await self._cancel_task_hierarchy(task, notify=True)
|
|
752
|
-
else:
|
|
753
|
-
dev_logger.error('Task failed', err, {'task_id': task.task_id})
|
|
754
|
-
|
|
755
|
-
error = get_error_for_channel(err)
|
|
756
|
-
message = {'status': 'ERROR', 'task_id': task.task_id, 'error': error['error']}
|
|
757
|
-
# Notify about this task failing, and a server broadcast error
|
|
758
|
-
await self._send_notification_for_pending_task(
|
|
759
|
-
pending_task=pending_task,
|
|
760
|
-
messages=[message, error],
|
|
761
|
-
)
|
|
762
|
-
# notify related tasks
|
|
763
|
-
await self._multicast_notification(
|
|
562
|
+
# Notify any channels that need to be notified about the whole task being completed
|
|
563
|
+
await send_stream.send(
|
|
564
|
+
TaskResult(
|
|
764
565
|
task_id=task.task_id,
|
|
765
|
-
|
|
566
|
+
result=result,
|
|
567
|
+
cache_key=task.cache_key,
|
|
568
|
+
reg_entry=task.reg_entry,
|
|
766
569
|
)
|
|
570
|
+
)
|
|
571
|
+
eng_logger.info(f'TaskManager finished task {task.task_id}', {'result': result})
|
|
572
|
+
except (Exception, ExceptionGroup) as err:
|
|
573
|
+
err = resolve_exception_group(err)
|
|
574
|
+
|
|
575
|
+
# Mark pending task as failed
|
|
576
|
+
self.tasks[task.task_id].fail(err)
|
|
767
577
|
|
|
768
|
-
|
|
769
|
-
|
|
578
|
+
dev_logger.error('Task failed', err, {'task_id': task.task_id})
|
|
579
|
+
await self.set_result(task.task_id, {'error': str(err)})
|
|
580
|
+
|
|
581
|
+
# Notify any channels that need to be notified
|
|
582
|
+
await notify_channels({'status': 'ERROR', 'task_id': task.task_id}, get_error_for_channel())
|
|
583
|
+
finally:
|
|
584
|
+
# Remove the task from the running tasks
|
|
585
|
+
self.tasks.pop(task.task_id, None)
|
|
770
586
|
|
|
771
587
|
# Make sure streams are closed
|
|
772
588
|
with move_on_after(3, shield=True):
|
|
@@ -774,7 +590,7 @@ class TaskManager:
|
|
|
774
590
|
await receive_stream.aclose()
|
|
775
591
|
|
|
776
592
|
|
|
777
|
-
TaskResultEntry = CachedRegistryEntry(uid='task-results', cache=
|
|
593
|
+
TaskResultEntry = CachedRegistryEntry(uid='task-results', cache=Cache.Policy.KeepAll())
|
|
778
594
|
"""
|
|
779
595
|
Global registry entry for task results.
|
|
780
596
|
This is global because task ids are unique and accessed one time only so it's effectively a one-time use random key.
|