dara-core 1.21.16__py3-none-any.whl → 1.21.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dara/core/auth/base.py +5 -5
- dara/core/auth/basic.py +3 -3
- dara/core/auth/definitions.py +13 -14
- dara/core/auth/routes.py +7 -5
- dara/core/auth/utils.py +11 -10
- dara/core/base_definitions.py +30 -36
- dara/core/cli.py +7 -8
- dara/core/configuration.py +51 -58
- dara/core/css.py +2 -2
- dara/core/data_utils.py +12 -17
- dara/core/defaults.py +3 -3
- dara/core/definitions.py +58 -63
- dara/core/http.py +4 -4
- dara/core/interactivity/actions.py +34 -42
- dara/core/interactivity/any_data_variable.py +1 -1
- dara/core/interactivity/any_variable.py +6 -5
- dara/core/interactivity/client_variable.py +1 -2
- dara/core/interactivity/condition.py +2 -2
- dara/core/interactivity/data_variable.py +2 -4
- dara/core/interactivity/derived_data_variable.py +7 -10
- dara/core/interactivity/derived_variable.py +45 -51
- dara/core/interactivity/filtering.py +19 -19
- dara/core/interactivity/loop_variable.py +2 -4
- dara/core/interactivity/non_data_variable.py +1 -1
- dara/core/interactivity/plain_variable.py +22 -18
- dara/core/interactivity/server_variable.py +13 -15
- dara/core/interactivity/state_variable.py +4 -5
- dara/core/interactivity/switch_variable.py +16 -16
- dara/core/interactivity/tabular_variable.py +3 -3
- dara/core/interactivity/url_variable.py +3 -3
- dara/core/internal/cache_store/cache_store.py +6 -6
- dara/core/internal/cache_store/keep_all.py +3 -3
- dara/core/internal/cache_store/lru.py +8 -8
- dara/core/internal/cache_store/ttl.py +4 -4
- dara/core/internal/custom_response.py +3 -3
- dara/core/internal/dependency_resolution.py +6 -10
- dara/core/internal/devtools.py +2 -3
- dara/core/internal/download.py +5 -6
- dara/core/internal/encoder_registry.py +7 -11
- dara/core/internal/execute_action.py +5 -5
- dara/core/internal/hashing.py +1 -2
- dara/core/internal/import_discovery.py +7 -9
- dara/core/internal/normalization.py +12 -15
- dara/core/internal/pandas_utils.py +6 -6
- dara/core/internal/pool/channel.py +3 -4
- dara/core/internal/pool/definitions.py +9 -9
- dara/core/internal/pool/task_pool.py +8 -8
- dara/core/internal/pool/utils.py +4 -3
- dara/core/internal/pool/worker.py +3 -3
- dara/core/internal/registries.py +4 -4
- dara/core/internal/registry.py +3 -3
- dara/core/internal/registry_lookup.py +4 -4
- dara/core/internal/routing.py +34 -37
- dara/core/internal/scheduler.py +8 -8
- dara/core/internal/settings.py +1 -2
- dara/core/internal/store.py +9 -9
- dara/core/internal/tasks.py +30 -30
- dara/core/internal/utils.py +9 -15
- dara/core/internal/websocket.py +18 -18
- dara/core/js_tooling/js_utils.py +19 -19
- dara/core/logging.py +13 -13
- dara/core/main.py +11 -6
- dara/core/metrics/cache.py +2 -4
- dara/core/persistence.py +19 -25
- dara/core/router/compat.py +1 -3
- dara/core/router/components.py +10 -10
- dara/core/router/dependency_graph.py +2 -4
- dara/core/router/router.py +43 -42
- dara/core/umd/dara.core.umd.cjs +44 -197
- dara/core/visual/components/dynamic_component.py +1 -3
- dara/core/visual/components/fallback.py +3 -3
- dara/core/visual/components/for_cmp.py +5 -5
- dara/core/visual/components/menu.py +1 -3
- dara/core/visual/components/router_content.py +1 -3
- dara/core/visual/components/sidebar_frame.py +8 -10
- dara/core/visual/components/theme_provider.py +3 -3
- dara/core/visual/components/topbar_frame.py +8 -10
- dara/core/visual/css/__init__.py +277 -277
- dara/core/visual/dynamic_component.py +18 -22
- dara/core/visual/progress_updater.py +1 -1
- dara/core/visual/template.py +10 -12
- dara/core/visual/themes/definitions.py +46 -46
- {dara_core-1.21.16.dist-info → dara_core-1.21.18.dist-info}/METADATA +13 -13
- dara_core-1.21.18.dist-info/RECORD +127 -0
- dara_core-1.21.16.dist-info/RECORD +0 -127
- {dara_core-1.21.16.dist-info → dara_core-1.21.18.dist-info}/LICENSE +0 -0
- {dara_core-1.21.16.dist-info → dara_core-1.21.18.dist-info}/WHEEL +0 -0
- {dara_core-1.21.16.dist-info → dara_core-1.21.18.dist-info}/entry_points.txt +0 -0
dara/core/internal/registry.py
CHANGED

@@ -18,7 +18,7 @@ limitations under the License.
 import copy
 from collections.abc import MutableMapping
 from enum import Enum
-from typing import Generic,
+from typing import Generic, TypeVar

 from dara.core.metrics import CACHE_METRICS_TRACKER, total_size

@@ -57,8 +57,8 @@ class Registry(Generic[T]):
     def __init__(
         self,
         name: RegistryType,
-        initial_registry:
-        allow_duplicates:
+        initial_registry: MutableMapping[str, T] | None = None,
+        allow_duplicates: bool | None = True,
     ):
         """
         :param name: human readable name of the registry; used for metrics

dara/core/internal/registry_lookup.py
CHANGED

@@ -15,8 +15,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

-from collections.abc import Coroutine
-from typing import
+from collections.abc import Callable, Coroutine
+from typing import Literal, TypeVar

 from dara.core.internal.registry import Registry, RegistryType
 from dara.core.internal.utils import async_dedupe

@@ -31,7 +31,7 @@ RegistryLookupKey = Literal[
     RegistryType.BACKEND_STORE,
     RegistryType.DOWNLOAD_CODE,
 ]
-CustomRegistryLookup =
+CustomRegistryLookup = dict[RegistryLookupKey, Callable[[str], Coroutine]]

 RegistryType = TypeVar('RegistryType')

@@ -41,7 +41,7 @@ class RegistryLookup:
     Manages registry Lookup.
     """

-    def __init__(self, handlers:
+    def __init__(self, handlers: CustomRegistryLookup | None = None):
         if handlers is None:
             handlers = {}
         self.handlers = handlers
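Most hunks in this release follow the same pattern as the `Registry` and `RegistryLookup` changes above: capitalized `typing` aliases and `Optional[...]` give way to built-in generics and PEP 604 `X | None` unions. A minimal before/after sketch of that pattern (illustrative only, not code taken from the package):

```python
from collections.abc import MutableMapping
from typing import Generic, TypeVar

T = TypeVar('T')

# Older style (illustrative): Optional[...] and Dict[...] imported from typing.
# Newer style, as in the hunks above: built-in generics plus "X | None" unions.


class Registry(Generic[T]):
    def __init__(self, initial_registry: MutableMapping[str, T] | None = None):
        # copy the initial mapping (or start empty) into a plain dict
        self._entries: dict[str, T] = dict(initial_registry or {})
```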
dara/core/internal/routing.py
CHANGED

@@ -20,13 +20,14 @@ import json
 import math
 import os
 import traceback
-from collections.abc import Mapping
+from collections.abc import Callable, Mapping
 from functools import wraps
 from importlib.metadata import version
-from typing import Annotated, Any,
+from typing import Annotated, Any, Literal
 from urllib.parse import unquote

 import anyio
+import backoff
 from anyio.streams.memory import MemoryObjectSendStream
 from fastapi import (
     APIRouter,

@@ -37,6 +38,7 @@ from fastapi import (
     Form,
     HTTPException,
     Path,
+    Query,
     Response,
     UploadFile,
 )

@@ -213,20 +215,6 @@ async def get_download(code: str):
         raise ValueError('Invalid or expired download code') from e


-@core_api_router.get('/components/{name}/definition', dependencies=[Depends(verify_session)])
-async def get_component_definition(name: str):
-    """
-    Attempt to refetch a component definition from the backend.
-    This is used when a component isn't immediately available in the initial registry,
-    e.g. when it was added by a py_component.
-
-    :param name: the name of component
-    """
-    registry_mgr: RegistryLookup = utils_registry.get('RegistryLookup')
-    component = await registry_mgr.get(component_registry, name)
-    return component.model_dump(exclude={'func'})
-
-
 class ComponentRequestBody(BaseModel):
     # Dynamic kwarg values
     values: NormalizedPayload[Mapping[str, Any]]

@@ -243,7 +231,16 @@ async def get_component(component: str, body: ComponentRequestBody):
     store: CacheStore = utils_registry.get('Store')
     task_mgr: TaskManager = utils_registry.get('TaskManager')
     registry_mgr: RegistryLookup = utils_registry.get('RegistryLookup')
-
+
+    # Retry 5 times with a constant backoff of 1 s with 0-1 jitter
+    @backoff.on_exception(backoff.constant, ValueError, max_tries=5, jitter=backoff.full_jitter)
+    async def _get_component():
+        try:
+            return await registry_mgr.get(component_registry, component)
+        except Exception as e:
+            raise ValueError(f'Could now resolve component {component}. Was it registered in the app?') from e
+
+    comp_def = await _get_component()

     if isinstance(comp_def, PyComponentDef):
         static_kwargs = await registry_mgr.get(static_kwargs_registry, body.uid)

@@ -295,11 +292,11 @@ async def get_latest_derived_variable(uid: str):


 class TabularRequestBody(BaseModel):
-    filters:
+    filters: FilterQuery | None = None
     ws_channel: str
-    dv_values:
+    dv_values: NormalizedPayload[list[Any]] | None = None
     """DerivedVariable values if variable is a DerivedVariable"""
-    force_key:
+    force_key: str | None = None
     """Optional force key if variable is a DerivedVariable and a recalculation is forced"""


@@ -307,10 +304,10 @@ class TabularRequestBody(BaseModel):
 async def get_tabular_variable(
     uid: str,
     body: TabularRequestBody,
-    offset:
-    limit:
-    order_by:
-    index:
+    offset: int | None = None,
+    limit: int | None = None,
+    order_by: str | None = None,
+    index: str | None = None,
 ):
     """
     Generic endpoint for getting tabular data from a variable.

@@ -359,9 +356,9 @@ async def get_server_variable_sequence(

 @core_api_router.post('/data/upload', dependencies=[Depends(verify_session)])
 async def upload_data(
-
-
-
+    data: Annotated[UploadFile, File()],
+    resolver_id: Annotated[str | None, Form()] = None,
+    data_uid: Annotated[str | None, Query()] = None,
 ):
     """
     Upload endpoint.

@@ -391,8 +388,8 @@ async def upload_data(


 class DerivedStateRequestBody(BaseModel):
-    values: NormalizedPayload[
-    force_key:
+    values: NormalizedPayload[list[Any]]
+    force_key: str | None = None
     ws_channel: str


@@ -448,7 +445,7 @@ async def read_backend_store(store_uid: str):


 @core_api_router.post('/store', dependencies=[Depends(verify_session)])
-async def sync_backend_store(ws_channel: str
+async def sync_backend_store(ws_channel: Annotated[str, Body()], values: Annotated[dict[str, Any], Body()]):
     registry_mgr: RegistryLookup = utils_registry.get('RegistryLookup')

     async def _write(store_uid: str, value: Any):

@@ -519,7 +516,7 @@ class ActionPayload(BaseModel):

 class DerivedVariablePayload(BaseModel):
     uid: str
-    values: NormalizedPayload[
+    values: NormalizedPayload[list[Any]]


 class PyComponentPayload(BaseModel):

@@ -529,11 +526,11 @@ class PyComponentPayload(BaseModel):


 class RouteDataRequestBody(BaseModel):
-    action_payloads:
-    derived_variable_payloads:
-    py_component_payloads:
+    action_payloads: list[ActionPayload] = Field(default_factory=list)
+    derived_variable_payloads: list[DerivedVariablePayload] = Field(default_factory=list)
+    py_component_payloads: list[PyComponentPayload] = Field(default_factory=list)
     ws_channel: str
-    params:
+    params: dict[str, str] = Field(default_factory=dict)


 class Result(BaseModel):

@@ -561,7 +558,7 @@ class PyComponentChunk(BaseModel):
     result: Result


-Chunk =
+Chunk = DerivedVariableChunk | PyComponentChunk


 def create_loader_route(config: Configuration, app: FastAPI):

@@ -577,7 +574,7 @@ def create_loader_route(config: Configuration, app: FastAPI):
     if route_data is None:
         raise HTTPException(status_code=404, detail=f'Route {route_id} not found')

-    action_results:
+    action_results: dict[str, Any] = {}

     if len(body.action_payloads) > 0:
         store: CacheStore = utils_registry.get('Store')
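The retry logic added to `get_component` relies on the third-party `backoff` package, newly imported at the top of `routing.py`; the lookup now gets a few extra attempts before failing instead of erroring immediately. A minimal, self-contained sketch of the same decorator usage follows; the `resolve` function and its in-memory registry are hypothetical stand-ins, not Dara's actual lookup code:

```python
import asyncio

import backoff

_COMPONENTS = {'my-component': {'js_module': '...'}}  # hypothetical stand-in registry


# Same shape as the decorator above: up to 5 attempts, constant ~1 s wait, full jitter
@backoff.on_exception(backoff.constant, ValueError, max_tries=5, jitter=backoff.full_jitter)
async def resolve(name: str) -> dict:
    if name not in _COMPONENTS:
        raise ValueError(f'Could not resolve component {name}')
    return _COMPONENTS[name]


if __name__ == '__main__':
    print(asyncio.run(resolve('my-component')))
```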
dara/core/internal/scheduler.py
CHANGED

@@ -20,19 +20,19 @@ from datetime import datetime
 from multiprocessing import get_context
 from multiprocessing.process import BaseProcess
 from pickle import PicklingError
-from typing import Any,
+from typing import Any, cast

 from croniter import croniter
 from pydantic import BaseModel, field_validator


 class ScheduledJob(BaseModel):
-    interval:
+    interval: int | list[int]
     continue_running: bool = True
     first_execution: bool = True
     run_once: bool

-    def __init__(self, interval:
+    def __init__(self, interval: int | list, run_once=False, **kwargs):
         """
         Creates a ScheduledJob object


@@ -126,7 +126,7 @@ class CronScheduledJob(ScheduledJob):
 class TimeScheduledJob(ScheduledJob):
     job_time: datetime

-    def __init__(self, interval:
+    def __init__(self, interval: int | list, job_time: str, run_once=False):
         """
         Creates a TimeScheduledJob object


@@ -172,9 +172,9 @@ class ScheduledJobFactory(BaseModel):
     :param run_once: Whether the job should be run only once
     """

-    interval:
+    interval: int | list
     continue_running: bool = True
-    weekday:
+    weekday: datetime | None = None
     run_once: bool

     @field_validator('weekday', mode='before')

@@ -197,7 +197,7 @@ class ScheduledJobFactory(BaseModel):
         # If the job is scheduled to execute on a weekly basis
         if self.weekday is not None:
             # Set 2 intervals, where the first interval is the time from now until the first execution
-            interval = [(self.weekday - datetime.utcnow()).seconds, self.interval]
+            interval: list | int = [(self.weekday - datetime.utcnow()).seconds, self.interval]
         else:
             interval = self.interval
         job = TimeScheduledJob(interval, job_time, run_once=self.run_once)

@@ -214,7 +214,7 @@ class ScheduledJobFactory(BaseModel):
         """
         if self.weekday is not None:
             # Set 2 intervals, where the first interval is the time from now until the first execution
-            interval = [(self.weekday - datetime.utcnow()).seconds, self.interval]
+            interval: list | int = [(self.weekday - datetime.utcnow()).seconds, self.interval]
         else:
             interval = self.interval
         job = ScheduledJob(interval, run_once=self.run_once)
dara/core/internal/settings.py
CHANGED

@@ -18,7 +18,6 @@ limitations under the License.
 import os
 from functools import lru_cache
 from secrets import token_hex
-from typing import List, Optional

 from dotenv import dotenv_values
 from pydantic_settings import BaseSettings, SettingsConfigDict

@@ -44,7 +43,7 @@ class Settings(BaseSettings):
     sso_groups: str = ''
     sso_jwt_algo: str = 'ES256'
     sso_verify_audience: bool = False
-    sso_extra_audience:
+    sso_extra_audience: list[str] | None = None
     model_config = SettingsConfigDict(env_file='.env', extra='allow')

dara/core/internal/store.py
CHANGED

@@ -15,7 +15,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

-from typing import Any
+from typing import Any

 from dara.core.base_definitions import CacheType, PendingTask
 from dara.core.internal.utils import get_cache_scope

@@ -34,9 +34,9 @@ class Store:
     def __init__(self):
         # This dict is the main store of values, the first level of keys is the cache type and the second level are the
         # value keys. The root key is used for any non-session/user dependant keys that are added.
-        self._store:
+        self._store: dict[str, dict[str, Any]] = {'global': {}}

-    def get(self, key: str, cache_type:
+    def get(self, key: str, cache_type: CacheType | None = CacheType.GLOBAL) -> Any:
         """
         Get a given key from the store and optionally pull it from the specified cache store or the global one

@@ -47,7 +47,7 @@ class Store:
         cache_key = get_cache_scope(cache_type)
         return self._store.get(cache_key, {}).get(key)

-    async def get_or_wait(self, key: str, cache_type:
+    async def get_or_wait(self, key: str, cache_type: CacheType | None = CacheType.GLOBAL) -> Any:
         """
         Wait for the given key to not be pending and then return the value. Optionally pull it from the specified cache specific
         store or the global one

@@ -67,8 +67,8 @@ class Store:
         self,
         key: str,
         value: Any,
-        cache_type:
-        error:
+        cache_type: CacheType | None = CacheType.GLOBAL,
+        error: Exception | None = None,
     ):
         """
         Set the value of a given key in the store. the cache_type flag can be used to optionally pull the value from

@@ -86,7 +86,7 @@ class Store:

         self._store[cache_key][key] = value

-    def set_pending_task(self, key: str, pending_task: PendingTask, cache_type:
+    def set_pending_task(self, key: str, pending_task: PendingTask, cache_type: CacheType | None = CacheType.GLOBAL):
         """
         Store a pending task state for a given key in the store. This will trigger the async behavior of the get call if subsequent
         requests ask for the same key. The PendingTask will be resolved once the underlying task is completed.

@@ -102,7 +102,7 @@ class Store:

         self._store[cache_key][key] = pending_task

-    def remove_starting_with(self, start: str, cache_type:
+    def remove_starting_with(self, start: str, cache_type: CacheType | None = CacheType.GLOBAL):
         """
         Remove any entries stored under keys starting with given string


@@ -134,7 +134,7 @@ class Store:
             if not isinstance(cache_type_store[key], PendingTask):
                 cache_type_store.pop(key)

-    def list(self, cache_type:
+    def list(self, cache_type: CacheType | None = CacheType.GLOBAL) -> list[str]:
         """
         List all keys in a specified cache store. Listed the global store unless cache_type is not None
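As the `__init__` hunk above shows, `Store` keeps a two-level dict: the outer key is a cache scope (resolved from the `CacheType` via `get_cache_scope`, with `'global'` as the default) and the inner key is the stored value's key. A rough sketch of that layout, using a plain string scope instead of `CacheType` (not the actual Dara class):

```python
from typing import Any


class ScopedStore:
    """Two-level store: outer key is a cache scope, inner key is the value key."""

    def __init__(self) -> None:
        self._store: dict[str, dict[str, Any]] = {'global': {}}

    def set(self, key: str, value: Any, scope: str = 'global') -> None:
        self._store.setdefault(scope, {})[key] = value

    def get(self, key: str, scope: str = 'global') -> Any:
        # unknown scopes or missing keys fall through to None rather than raising
        return self._store.get(scope, {}).get(key)


store = ScopedStore()
store.set('token', 'abc', scope='session-123')
assert store.get('token', scope='session-123') == 'abc'
assert store.get('token') is None
```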
dara/core/internal/tasks.py
CHANGED

@@ -18,8 +18,8 @@ limitations under the License.
 import contextlib
 import inspect
 import math
-from collections.abc import Awaitable
-from typing import Any,
+from collections.abc import Awaitable, Callable
+from typing import Any, cast, overload

 from anyio import (
     BrokenResourceError,

@@ -68,13 +68,13 @@ class Task(BaseTask):
     def __init__(
         self,
         func: Callable,
-        args:
-        kwargs:
-        reg_entry:
-        notify_channels:
-        cache_key:
-        task_id:
-        on_progress:
+        args: list[Any] | None = None,
+        kwargs: dict[str, Any] | None = None,
+        reg_entry: CachedRegistryEntry | None = None,
+        notify_channels: list[str] | None = None,
+        cache_key: str | None = None,
+        task_id: str | None = None,
+        on_progress: Callable[[TaskProgressUpdate], None | Awaitable[None]] | None = None,
     ):
         """
         :param func: The function to execute within the process

@@ -122,7 +122,7 @@ class Task(BaseTask):

         return func.__name__

-    async def run(self, send_stream:
+    async def run(self, send_stream: MemoryObjectSendStream[TaskMessage] | None = None) -> Any:
         """
         Run the task asynchronously, and await its' end.

@@ -200,13 +200,13 @@ class MetaTask(BaseTask):
     def __init__(
         self,
         process_result: Callable[..., Any],
-        args:
-        kwargs:
-        reg_entry:
-        notify_channels:
+        args: list[Any] | None = None,
+        kwargs: dict[str, Any] | None = None,
+        reg_entry: CachedRegistryEntry | None = None,
+        notify_channels: list[str] | None = None,
         process_as_task: bool = False,
-        cache_key:
-        task_id:
+        cache_key: str | None = None,
+        task_id: str | None = None,
     ):
         """
         :param process result: A function to process the result of the other tasks

@@ -224,13 +224,13 @@ class MetaTask(BaseTask):
         self.kwargs = kwargs if kwargs is not None else {}
         self.notify_channels = notify_channels if notify_channels is not None else []
         self.process_as_task = process_as_task
-        self.cancel_scope:
+        self.cancel_scope: CancelScope | None = None
         self.cache_key = cache_key
         self.reg_entry = reg_entry

         super().__init__(task_id)

-    async def run(self, send_stream:
+    async def run(self, send_stream: MemoryObjectSendStream[TaskMessage] | None = None):
         """
         Run any tasks found in the arguments to completion, collect the results and then call the process result
         function as a further task with a resultant arguments

@@ -241,7 +241,7 @@ class MetaTask(BaseTask):

         try:
             with self.cancel_scope:
-                tasks:
+                tasks: list[BaseTask] = []

                 # Collect up the tasks that need to be run and kick them off without awaiting them.
                 tasks.extend(x for x in self.args if isinstance(x, BaseTask))

@@ -250,7 +250,7 @@ class MetaTask(BaseTask):
                 eng_logger.info(f'MetaTask {self.task_id} running sub-tasks', {'task_ids': [x.task_id for x in tasks]})

                 # Wait for all tasks to complete
-                results:
+                results: dict[str, Any] = {}

                 async def _run_and_capture_result(task: BaseTask):
                     """

@@ -382,7 +382,7 @@ class TaskManager:
     """

     def __init__(self, task_group: TaskGroup, ws_manager: WebsocketManager, store: CacheStore):
-        self.tasks:
+        self.tasks: dict[str, PendingTask] = {}
         self.task_group = task_group
         self.ws_manager = ws_manager
         self.store = store

@@ -396,12 +396,12 @@ class TaskManager:
         return pending_task

     @overload
-    async def run_task(self, task: PendingTask, ws_channel:
+    async def run_task(self, task: PendingTask, ws_channel: str | None = None) -> Any: ...

     @overload
-    async def run_task(self, task: BaseTask, ws_channel:
+    async def run_task(self, task: BaseTask, ws_channel: str | None = None) -> PendingTask: ...

-    async def run_task(self, task: BaseTask, ws_channel:
+    async def run_task(self, task: BaseTask, ws_channel: str | None = None):
         """
         Run a task and store it in the tasks dict

@@ -438,7 +438,7 @@ class TaskManager:

         return pending_task

-    async def _cancel_tasks(self, task_ids:
+    async def _cancel_tasks(self, task_ids: list[str], notify: bool = True):
         """
         Cancel a list of tasks

@@ -534,7 +534,7 @@ class TaskManager:
         """
         return await self.store.set(TaskResultEntry, key=task_id, value=value)

-    def _collect_all_task_ids_in_hierarchy(self, task: BaseTask) ->
+    def _collect_all_task_ids_in_hierarchy(self, task: BaseTask) -> set[str]:
         """
         Recursively collect all task IDs in the task hierarchy

@@ -556,7 +556,7 @@ class TaskManager:

         return task_ids

-    async def _multicast_notification(self, task_id: str, messages:
+    async def _multicast_notification(self, task_id: str, messages: list[dict], variable_task_id: bool = True):
         """
         Send notifications to all task IDs that are related to a given task

@@ -585,7 +585,7 @@ class TaskManager:
             pending_task = self.tasks[pending_task_id]
             task_tg.start_soon(self._send_notification_for_task, pending_task, messages, variable_task_id)

-    async def _send_notification_for_task(self, task: BaseTask, messages:
+    async def _send_notification_for_task(self, task: BaseTask, messages: list[dict], variable_task_id: bool = True):
         """
         Send notifications for a specific PendingTask

@@ -713,7 +713,7 @@ class TaskManager:
            # Remove the task from the registered tasks - it finished running
            self.tasks.pop(message.task_id, None)

-           task_error:
+           task_error: ExceptionGroup | None = None

            # ExceptionGroup handler can't be async so we just mark the task as errored
            # and run the async handler in the finally block

@@ -746,7 +746,7 @@ class TaskManager:
         finally:
             with CancelScope(shield=True):
                 # cast explicitly as otherwise pyright thinks it's always None here
-                task_error = cast(
+                task_error = cast(ExceptionGroup | None, task_error)
                 if task_error is not None:
                     err = task_error
                     # Mark pending task as failed
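The `run_task` hunk keeps the two `@overload` stubs and only modernizes their annotations: the type checker sees `Any` when a `PendingTask` is passed and a `PendingTask` for any other `BaseTask`. A generic sketch of that overload pattern with toy classes (not Dara's task types):

```python
import asyncio
from typing import Any, overload


class BaseTask: ...


class PendingTask(BaseTask): ...


@overload
async def run_task(task: PendingTask) -> Any: ...
@overload
async def run_task(task: BaseTask) -> PendingTask: ...


async def run_task(task: BaseTask) -> Any:
    # single runtime implementation; the overloads above only guide the type checker
    if isinstance(task, PendingTask):
        return 'result'
    return PendingTask()


print(asyncio.run(run_task(PendingTask())))
```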
dara/core/internal/utils.py
CHANGED

@@ -20,7 +20,7 @@ from __future__ import annotations
 import asyncio
 import inspect
 import os
-from collections.abc import Awaitable, Coroutine, Sequence
+from collections.abc import Awaitable, Callable, Coroutine, Sequence
 from functools import wraps
 from importlib import import_module
 from importlib.util import find_spec

@@ -28,14 +28,8 @@ from types import ModuleType
 from typing import (
     TYPE_CHECKING,
     Any,
-    Callable,
-    Dict,
     Literal,
-    Optional,
-    Tuple,
-    Type,
     TypeVar,
-    Union,
 )

 import anyio

@@ -55,10 +49,10 @@ if TYPE_CHECKING:

 # CacheScope stores as a key an user if cache is set to users, a session_id if cache is sessions or is set to 'global' otherwise
 # The value is a cache_key, for example the cache key used to store derived variable results to the store
-CacheScope =
+CacheScope = Literal['global'] | str


-def get_cache_scope(cache_type:
+def get_cache_scope(cache_type: CacheType | None) -> CacheScope:
     """
     Helper to resolve the cache scope

@@ -80,7 +74,7 @@ def get_cache_scope(cache_type: Optional[CacheType]) -> CacheScope:
     return 'global'


-async def run_user_handler(handler: Callable, args:
+async def run_user_handler(handler: Callable, args: Sequence | None = None, kwargs: dict | None = None):
     """
     Run a user-defined handler function. Runs sync functions in a threadpool.
     Handles SystemExits cleanly.

@@ -123,7 +117,7 @@ def call_async(handler: Callable[..., Coroutine], *args):
         portal.call(handler, *args)


-def import_config(config_path: str) ->
+def import_config(config_path: str) -> tuple[ModuleType, ConfigurationBuilder]:
     """
     Import Dara from specified config in format "my_package.my_module:variable_name"
     """

@@ -194,9 +188,9 @@ def async_dedupe(fn: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
     This decorator is useful for operations that might be triggered multiple times in parallel
     but should be executed only once to prevent redundant work or data fetches.
     """
-    locks:
-    results:
-    wait_counts:
+    locks: dict[tuple, anyio.Lock] = {}
+    results: dict[tuple, Any] = {}
+    wait_counts: dict[tuple, int] = {}

     is_method = 'self' in inspect.signature(fn).parameters

@@ -243,7 +237,7 @@ def resolve_exception_group(error: Any):
        return error


-def exception_group_contains(err_type:
+def exception_group_contains(err_type: type[BaseException], group: BaseExceptionGroup) -> bool:
     """
     Check if an ExceptionGroup contains an error of a given type, recursively
