dara-core 1.17.5__py3-none-any.whl → 1.18.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dara/core/__init__.py +2 -0
- dara/core/actions.py +1 -2
- dara/core/auth/basic.py +9 -9
- dara/core/auth/routes.py +5 -5
- dara/core/auth/utils.py +4 -4
- dara/core/base_definitions.py +15 -22
- dara/core/cli.py +8 -7
- dara/core/configuration.py +5 -2
- dara/core/css.py +1 -2
- dara/core/data_utils.py +2 -2
- dara/core/defaults.py +4 -7
- dara/core/definitions.py +6 -9
- dara/core/http.py +7 -3
- dara/core/interactivity/actions.py +28 -30
- dara/core/interactivity/any_data_variable.py +6 -5
- dara/core/interactivity/any_variable.py +4 -7
- dara/core/interactivity/data_variable.py +1 -1
- dara/core/interactivity/derived_data_variable.py +7 -6
- dara/core/interactivity/derived_variable.py +93 -33
- dara/core/interactivity/filtering.py +19 -27
- dara/core/interactivity/plain_variable.py +3 -2
- dara/core/interactivity/switch_variable.py +4 -4
- dara/core/internal/cache_store/base_impl.py +2 -1
- dara/core/internal/cache_store/cache_store.py +17 -5
- dara/core/internal/cache_store/keep_all.py +4 -1
- dara/core/internal/cache_store/lru.py +5 -1
- dara/core/internal/cache_store/ttl.py +4 -1
- dara/core/internal/cgroup.py +1 -1
- dara/core/internal/dependency_resolution.py +46 -10
- dara/core/internal/devtools.py +2 -2
- dara/core/internal/download.py +4 -3
- dara/core/internal/encoder_registry.py +7 -7
- dara/core/internal/execute_action.py +4 -10
- dara/core/internal/hashing.py +1 -3
- dara/core/internal/import_discovery.py +3 -4
- dara/core/internal/normalization.py +9 -13
- dara/core/internal/pandas_utils.py +3 -3
- dara/core/internal/pool/task_pool.py +16 -10
- dara/core/internal/pool/utils.py +5 -7
- dara/core/internal/pool/worker.py +3 -2
- dara/core/internal/port_utils.py +1 -1
- dara/core/internal/registries.py +9 -4
- dara/core/internal/registry.py +3 -1
- dara/core/internal/registry_lookup.py +7 -3
- dara/core/internal/routing.py +77 -44
- dara/core/internal/scheduler.py +13 -8
- dara/core/internal/settings.py +2 -2
- dara/core/internal/tasks.py +8 -14
- dara/core/internal/utils.py +11 -10
- dara/core/internal/websocket.py +18 -19
- dara/core/js_tooling/js_utils.py +23 -24
- dara/core/logging.py +3 -6
- dara/core/main.py +14 -11
- dara/core/metrics/cache.py +1 -1
- dara/core/metrics/utils.py +3 -3
- dara/core/persistence.py +1 -1
- dara/core/umd/dara.core.umd.js +149 -128
- dara/core/visual/components/__init__.py +2 -2
- dara/core/visual/components/fallback.py +3 -3
- dara/core/visual/css/__init__.py +30 -31
- dara/core/visual/dynamic_component.py +10 -11
- dara/core/visual/progress_updater.py +4 -3
- {dara_core-1.17.5.dist-info → dara_core-1.18.0.dist-info}/METADATA +10 -10
- dara_core-1.18.0.dist-info/RECORD +114 -0
- dara_core-1.17.5.dist-info/RECORD +0 -114
- {dara_core-1.17.5.dist-info → dara_core-1.18.0.dist-info}/LICENSE +0 -0
- {dara_core-1.17.5.dist-info → dara_core-1.18.0.dist-info}/WHEEL +0 -0
- {dara_core-1.17.5.dist-info → dara_core-1.18.0.dist-info}/entry_points.txt +0 -0
|
@@ -88,12 +88,13 @@ class LRUCache(CacheStoreImpl[LruCachePolicy]):
|
|
|
88
88
|
self.cache.pop(key, None)
|
|
89
89
|
return node.value
|
|
90
90
|
|
|
91
|
-
async def get(self, key: str, unpin: bool = False) -> Optional[Any]:
|
|
91
|
+
async def get(self, key: str, unpin: bool = False, raise_for_missing: bool = False) -> Optional[Any]:
|
|
92
92
|
"""
|
|
93
93
|
Retrieve a value from the cache.
|
|
94
94
|
|
|
95
95
|
:param key: The key of the value to retrieve.
|
|
96
96
|
:param unpin: If true, the entry will be unpinned if it is pinned.
|
|
97
|
+
:param raise_for_missing: If true, an exception will be raised if the entry is not found
|
|
97
98
|
:return: The value associated with the key, or None if the key is not in the cache.
|
|
98
99
|
"""
|
|
99
100
|
async with self.lock:
|
|
@@ -103,6 +104,9 @@ class LRUCache(CacheStoreImpl[LruCachePolicy]):
|
|
|
103
104
|
node.pin = False
|
|
104
105
|
self._move_to_front(node)
|
|
105
106
|
return node.value
|
|
107
|
+
|
|
108
|
+
if raise_for_missing:
|
|
109
|
+
raise KeyError(f'No cache entry found for {key}')
|
|
106
110
|
return None
|
|
107
111
|
|
|
108
112
|
async def set(self, key: str, value: Any, pin: bool = False) -> None:
|
|
@@ -52,12 +52,13 @@ class TTLCache(CacheStoreImpl[TTLCachePolicy]):
|
|
|
52
52
|
_, key = heapq.heappop(self.expiration_heap)
|
|
53
53
|
self.unpinned_cache.pop(key, None)
|
|
54
54
|
|
|
55
|
-
async def get(self, key: str, unpin: bool = False) -> Any:
|
|
55
|
+
async def get(self, key: str, unpin: bool = False, raise_for_missing: bool = False) -> Any:
|
|
56
56
|
"""
|
|
57
57
|
Retrieve a value from the cache.
|
|
58
58
|
|
|
59
59
|
:param key: The key of the value to retrieve.
|
|
60
60
|
:param unpin: If true, the entry will be unpinned if it is pinned.
|
|
61
|
+
:param raise_for_missing: If true, an exception will be raised if the entry is not found
|
|
61
62
|
:return: The value associated with the key, or None if the key is not in the cache.
|
|
62
63
|
"""
|
|
63
64
|
async with self.lock:
|
|
@@ -75,6 +76,8 @@ class TTLCache(CacheStoreImpl[TTLCachePolicy]):
|
|
|
75
76
|
elif key in self.unpinned_cache:
|
|
76
77
|
return self.unpinned_cache[key].value
|
|
77
78
|
|
|
79
|
+
if raise_for_missing:
|
|
80
|
+
raise KeyError(f'No cache entry found for {key}')
|
|
78
81
|
return None
|
|
79
82
|
|
|
80
83
|
async def set(self, key: str, value: Any, pin: bool = False) -> None:
|
dara/core/internal/cgroup.py
CHANGED
|
@@ -22,7 +22,7 @@ import sys
|
|
|
22
22
|
|
|
23
23
|
from dara.core.logging import dev_logger, eng_logger
|
|
24
24
|
|
|
25
|
-
CGROUP_V2_INDICATOR_PATH = '/sys/fs/cgroup/cgroup.controllers'
|
|
25
|
+
CGROUP_V2_INDICATOR_PATH = '/sys/fs/cgroup/cgroup.controllers' # Used to determine whether we're using cgroupv2
|
|
26
26
|
|
|
27
27
|
CGROUP_V1_MEM_PATH = '/sys/fs/cgroup/memory/memory.limit_in_bytes'
|
|
28
28
|
CGROUP_V2_MEM_PATH = '/sys/fs/cgroup/memory.max'
|
|
@@ -30,20 +30,18 @@ from dara.core.logging import dev_logger
|
|
|
30
30
|
|
|
31
31
|
|
|
32
32
|
class ResolvedDerivedVariable(TypedDict):
|
|
33
|
-
deps: List[int]
|
|
34
33
|
type: Literal['derived']
|
|
35
34
|
uid: str
|
|
36
35
|
values: List[Any]
|
|
37
|
-
|
|
36
|
+
force_key: Optional[str]
|
|
38
37
|
|
|
39
38
|
|
|
40
39
|
class ResolvedDerivedDataVariable(TypedDict):
|
|
41
|
-
deps: List[int]
|
|
42
40
|
type: Literal['derived-data']
|
|
43
41
|
uid: str
|
|
44
42
|
values: List[Any]
|
|
45
43
|
filters: Optional[Union[FilterQuery, dict]]
|
|
46
|
-
|
|
44
|
+
force_key: Optional[str]
|
|
47
45
|
|
|
48
46
|
|
|
49
47
|
class ResolvedDataVariable(TypedDict):
|
|
@@ -64,7 +62,9 @@ def is_resolved_derived_variable(obj: Any) -> TypeGuard[ResolvedDerivedVariable]
|
|
|
64
62
|
return isinstance(obj, dict) and 'uid' in obj and obj.get('type') == 'derived'
|
|
65
63
|
|
|
66
64
|
|
|
67
|
-
def is_resolved_derived_data_variable(
|
|
65
|
+
def is_resolved_derived_data_variable(
|
|
66
|
+
obj: Any,
|
|
67
|
+
) -> TypeGuard[ResolvedDerivedDataVariable]:
|
|
68
68
|
return isinstance(obj, dict) and 'uid' in obj and obj.get('type') == 'derived-data'
|
|
69
69
|
|
|
70
70
|
|
|
@@ -76,9 +76,29 @@ def is_resolved_switch_variable(obj: Any) -> TypeGuard[ResolvedSwitchVariable]:
|
|
|
76
76
|
return isinstance(obj, dict) and 'uid' in obj and obj.get('type') == 'switch'
|
|
77
77
|
|
|
78
78
|
|
|
79
|
+
def clean_force_key(value: Any) -> Any:
|
|
80
|
+
"""
|
|
81
|
+
Clean an argument to a value to remove force keys
|
|
82
|
+
"""
|
|
83
|
+
if value is None:
|
|
84
|
+
return value
|
|
85
|
+
|
|
86
|
+
if isinstance(value, dict):
|
|
87
|
+
# Remove force key from the value
|
|
88
|
+
value.pop('force_key', None)
|
|
89
|
+
return {k: clean_force_key(v) for k, v in value.items()}
|
|
90
|
+
if isinstance(value, list):
|
|
91
|
+
return [clean_force_key(v) for v in value]
|
|
92
|
+
return value
|
|
93
|
+
|
|
94
|
+
|
|
79
95
|
async def resolve_dependency(
|
|
80
96
|
entry: Union[
|
|
81
|
-
ResolvedDerivedDataVariable,
|
|
97
|
+
ResolvedDerivedDataVariable,
|
|
98
|
+
ResolvedDataVariable,
|
|
99
|
+
ResolvedDerivedVariable,
|
|
100
|
+
ResolvedSwitchVariable,
|
|
101
|
+
Any,
|
|
82
102
|
],
|
|
83
103
|
store: CacheStore,
|
|
84
104
|
task_mgr: TaskManager,
|
|
@@ -127,13 +147,21 @@ async def _resolve_derived_data_var(entry: ResolvedDerivedDataVariable, store: C
|
|
|
127
147
|
input_values: List[Any] = entry.get('values', [])
|
|
128
148
|
|
|
129
149
|
result = await DerivedDataVariable.resolve_value(
|
|
130
|
-
data_var,
|
|
150
|
+
data_entry=data_var,
|
|
151
|
+
dv_entry=dv_var,
|
|
152
|
+
store=store,
|
|
153
|
+
task_mgr=task_mgr,
|
|
154
|
+
args=input_values,
|
|
155
|
+
filters=entry.get('filters', None),
|
|
156
|
+
force_key=entry.get('force_key'),
|
|
131
157
|
)
|
|
132
158
|
return remove_index(result)
|
|
133
159
|
|
|
134
160
|
|
|
135
161
|
async def _resolve_derived_var(
|
|
136
|
-
derived_variable_entry: ResolvedDerivedVariable,
|
|
162
|
+
derived_variable_entry: ResolvedDerivedVariable,
|
|
163
|
+
store: CacheStore,
|
|
164
|
+
task_mgr: TaskManager,
|
|
137
165
|
):
|
|
138
166
|
"""
|
|
139
167
|
Resolve a derived variable from the registry and get it's new value based on the dynamic variable mapping passed
|
|
@@ -149,7 +177,11 @@ async def _resolve_derived_var(
|
|
|
149
177
|
var = await registry_mgr.get(derived_variable_registry, str(derived_variable_entry.get('uid')))
|
|
150
178
|
input_values: List[Any] = derived_variable_entry.get('values', [])
|
|
151
179
|
result = await DerivedVariable.get_value(
|
|
152
|
-
var,
|
|
180
|
+
var_entry=var,
|
|
181
|
+
store=store,
|
|
182
|
+
task_mgr=task_mgr,
|
|
183
|
+
args=input_values,
|
|
184
|
+
force_key=derived_variable_entry.get('force_key'),
|
|
153
185
|
)
|
|
154
186
|
return result['value']
|
|
155
187
|
|
|
@@ -228,7 +260,11 @@ def _evaluate_condition(condition: dict) -> bool:
|
|
|
228
260
|
raise ValueError(f'Unknown condition operator: {operator}')
|
|
229
261
|
|
|
230
262
|
|
|
231
|
-
async def _resolve_switch_var(
|
|
263
|
+
async def _resolve_switch_var(
|
|
264
|
+
switch_variable_entry: ResolvedSwitchVariable,
|
|
265
|
+
store: CacheStore,
|
|
266
|
+
task_mgr: TaskManager,
|
|
267
|
+
):
|
|
232
268
|
"""
|
|
233
269
|
Resolve a switch variable by evaluating its constituent parts and returning the appropriate value.
|
|
234
270
|
|
dara/core/internal/devtools.py
CHANGED
|
@@ -36,7 +36,7 @@ def print_stacktrace():
|
|
|
36
36
|
trc = 'Traceback (most recent call last):\n'
|
|
37
37
|
stackstr = trc + ''.join(traceback.format_list(stack))
|
|
38
38
|
if exc is not None:
|
|
39
|
-
stackstr += ' ' + traceback.format_exc().lstrip(trc)
|
|
39
|
+
stackstr += ' ' + traceback.format_exc().lstrip(trc)
|
|
40
40
|
else:
|
|
41
41
|
stackstr += ' Exception'
|
|
42
42
|
|
|
@@ -52,7 +52,7 @@ def handle_system_exit(error_msg: str):
|
|
|
52
52
|
try:
|
|
53
53
|
yield
|
|
54
54
|
except SystemExit as e:
|
|
55
|
-
raise InterruptedError(error_msg)
|
|
55
|
+
raise InterruptedError(error_msg) from e
|
|
56
56
|
|
|
57
57
|
|
|
58
58
|
def get_error_for_channel() -> dict:
|
dara/core/internal/download.py
CHANGED
|
@@ -18,7 +18,8 @@ limitations under the License.
|
|
|
18
18
|
from __future__ import annotations
|
|
19
19
|
|
|
20
20
|
import os
|
|
21
|
-
from
|
|
21
|
+
from collections.abc import Awaitable
|
|
22
|
+
from typing import Callable, Optional, Tuple
|
|
22
23
|
from uuid import uuid4
|
|
23
24
|
|
|
24
25
|
import anyio
|
|
@@ -37,13 +38,13 @@ class DownloadDataEntry(BaseModel):
|
|
|
37
38
|
file_path: str
|
|
38
39
|
cleanup_file: bool
|
|
39
40
|
identity_name: Optional[str] = None
|
|
40
|
-
download: Callable[[
|
|
41
|
+
download: Callable[[DownloadDataEntry], Awaitable[Tuple[anyio.AsyncFile, Callable[..., Awaitable]]]]
|
|
41
42
|
"""Handler for getting the file from the entry"""
|
|
42
43
|
|
|
43
44
|
|
|
44
45
|
DownloadRegistryEntry = CachedRegistryEntry(
|
|
45
46
|
uid='_dara_download', cache=Cache.Policy.TTL(ttl=60 * 10)
|
|
46
|
-
)
|
|
47
|
+
) # expire the codes after 10 minutes
|
|
47
48
|
|
|
48
49
|
|
|
49
50
|
async def download(data_entry: DownloadDataEntry) -> Tuple[anyio.AsyncFile, Callable[..., Awaitable]]:
|
|
@@ -14,13 +14,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
14
14
|
See the License for the specific language governing permissions and
|
|
15
15
|
limitations under the License.
|
|
16
16
|
"""
|
|
17
|
-
|
|
17
|
+
|
|
18
|
+
from collections.abc import MutableMapping
|
|
18
19
|
from inspect import Parameter, isclass
|
|
19
20
|
from typing import (
|
|
20
21
|
Any,
|
|
21
22
|
Callable,
|
|
22
23
|
Dict,
|
|
23
|
-
MutableMapping,
|
|
24
24
|
Optional,
|
|
25
25
|
Type,
|
|
26
26
|
Union,
|
|
@@ -99,11 +99,11 @@ def _tuple_key_deserialize(d):
|
|
|
99
99
|
if isinstance(key, str) and key.startswith('__tuple__'):
|
|
100
100
|
key_list = []
|
|
101
101
|
for each in key[10:-1].split(', '):
|
|
102
|
-
if (each.startswith("'") and each.endswith(
|
|
102
|
+
if (each.startswith("'") and each.endswith("'")) or (each.startswith('"') and each.endswith('"')):
|
|
103
103
|
key_list.append(each[1:-1])
|
|
104
104
|
else:
|
|
105
105
|
key_list.append(each)
|
|
106
|
-
encoded_key =
|
|
106
|
+
encoded_key = tuple(key_list)
|
|
107
107
|
else:
|
|
108
108
|
encoded_key = key
|
|
109
109
|
|
|
@@ -112,7 +112,7 @@ def _tuple_key_deserialize(d):
|
|
|
112
112
|
return encoded_dict
|
|
113
113
|
|
|
114
114
|
|
|
115
|
-
def _df_deserialize(x):
|
|
115
|
+
def _df_deserialize(x):
|
|
116
116
|
"""
|
|
117
117
|
A function to deserialize data into a DataFrame
|
|
118
118
|
|
|
@@ -240,14 +240,14 @@ def deserialize(value: Any, typ: Optional[Type]):
|
|
|
240
240
|
return value
|
|
241
241
|
|
|
242
242
|
# Already matches type
|
|
243
|
-
if type(value)
|
|
243
|
+
if type(value) is typ:
|
|
244
244
|
return value
|
|
245
245
|
|
|
246
246
|
# Handle Optional[foo] / Union[foo, None] -> call deserialize(value, foo)
|
|
247
247
|
if get_origin(typ) == Union:
|
|
248
248
|
args = get_args(typ)
|
|
249
249
|
if len(args) == 2 and type(None) in args:
|
|
250
|
-
not_none_arg = args[0] if args[0]
|
|
250
|
+
not_none_arg = args[0] if args[0] is not type(None) else args[1]
|
|
251
251
|
return deserialize(value, not_none_arg)
|
|
252
252
|
|
|
253
253
|
try:
|
|
@@ -14,11 +14,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
14
14
|
See the License for the specific language governing permissions and
|
|
15
15
|
limitations under the License.
|
|
16
16
|
"""
|
|
17
|
+
|
|
17
18
|
from __future__ import annotations
|
|
18
19
|
|
|
19
20
|
import asyncio
|
|
21
|
+
from collections.abc import Mapping
|
|
20
22
|
from contextvars import ContextVar
|
|
21
|
-
from typing import Any, Callable,
|
|
23
|
+
from typing import Any, Callable, Optional, Union
|
|
22
24
|
|
|
23
25
|
import anyio
|
|
24
26
|
|
|
@@ -30,11 +32,7 @@ from dara.core.interactivity.actions import (
|
|
|
30
32
|
ActionImpl,
|
|
31
33
|
)
|
|
32
34
|
from dara.core.internal.cache_store import CacheStore
|
|
33
|
-
from dara.core.internal.dependency_resolution import
|
|
34
|
-
is_resolved_derived_data_variable,
|
|
35
|
-
is_resolved_derived_variable,
|
|
36
|
-
resolve_dependency,
|
|
37
|
-
)
|
|
35
|
+
from dara.core.internal.dependency_resolution import resolve_dependency
|
|
38
36
|
from dara.core.internal.encoder_registry import deserialize
|
|
39
37
|
from dara.core.internal.tasks import MetaTask, TaskManager
|
|
40
38
|
from dara.core.internal.utils import run_user_handler
|
|
@@ -146,10 +144,6 @@ async def execute_action(
|
|
|
146
144
|
annotations = action.__annotations__
|
|
147
145
|
|
|
148
146
|
for key, value in values.items():
|
|
149
|
-
# Override `force` property to be false
|
|
150
|
-
if is_resolved_derived_variable(value) or is_resolved_derived_data_variable(value):
|
|
151
|
-
value['force'] = False
|
|
152
|
-
|
|
153
147
|
typ = annotations.get(key)
|
|
154
148
|
val = await resolve_dependency(value, store, task_mgr)
|
|
155
149
|
resolved_kwargs[key] = deserialize(val, typ)
|
dara/core/internal/hashing.py
CHANGED
|
@@ -31,8 +31,6 @@ def hash_object(obj: Union[BaseModel, dict, None]):
|
|
|
31
31
|
if isinstance(obj, BaseModel):
|
|
32
32
|
obj = obj.model_dump()
|
|
33
33
|
|
|
34
|
-
filter_hash = hashlib.sha1(
|
|
35
|
-
usedforsecurity=False
|
|
36
|
-
) # nosec B303 # we don't use this for security purposes just as a cache key
|
|
34
|
+
filter_hash = hashlib.sha1(usedforsecurity=False) # nosec B303 # we don't use this for security purposes just as a cache key
|
|
37
35
|
filter_hash.update(json.dumps(obj or {}, sort_keys=True).encode())
|
|
38
36
|
return filter_hash.hexdigest()
|
|
@@ -93,10 +93,9 @@ def run_discovery(
|
|
|
93
93
|
# If module root is passed through, use it
|
|
94
94
|
if 'module_root' in kwargs:
|
|
95
95
|
root = kwargs.get('module_root')
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
root = module_name.split('.')[0]
|
|
96
|
+
# Try to infer from module_name
|
|
97
|
+
elif module_name is not None:
|
|
98
|
+
root = module_name.split('.')[0]
|
|
100
99
|
|
|
101
100
|
for k, v in global_symbols.items():
|
|
102
101
|
# Ignore already encountered functions
|
|
@@ -15,11 +15,11 @@ See the License for the specific language governing permissions and
|
|
|
15
15
|
limitations under the License.
|
|
16
16
|
"""
|
|
17
17
|
|
|
18
|
+
from collections.abc import Mapping
|
|
18
19
|
from typing import (
|
|
19
20
|
Any,
|
|
20
21
|
Generic,
|
|
21
22
|
List,
|
|
22
|
-
Mapping,
|
|
23
23
|
Optional,
|
|
24
24
|
Tuple,
|
|
25
25
|
TypeVar,
|
|
@@ -48,7 +48,7 @@ class Placeholder(TypedDict):
|
|
|
48
48
|
Placeholder object 'Referrable' objects are replaced with
|
|
49
49
|
"""
|
|
50
50
|
|
|
51
|
-
__ref: str
|
|
51
|
+
__ref: str
|
|
52
52
|
|
|
53
53
|
|
|
54
54
|
class Referrable(TypedDict):
|
|
@@ -56,7 +56,7 @@ class Referrable(TypedDict):
|
|
|
56
56
|
Describes an object which can be replaced by a Placeholder.
|
|
57
57
|
"""
|
|
58
58
|
|
|
59
|
-
__typename: str
|
|
59
|
+
__typename: str
|
|
60
60
|
uid: str
|
|
61
61
|
|
|
62
62
|
|
|
@@ -133,13 +133,11 @@ def _loop(iterable: JsonLike):
|
|
|
133
133
|
|
|
134
134
|
|
|
135
135
|
@overload
|
|
136
|
-
def normalize(obj: Mapping, check_root: bool = True) -> Tuple[Mapping, Mapping]:
|
|
137
|
-
...
|
|
136
|
+
def normalize(obj: Mapping, check_root: bool = True) -> Tuple[Mapping, Mapping]: ...
|
|
138
137
|
|
|
139
138
|
|
|
140
139
|
@overload
|
|
141
|
-
def normalize(obj: List, check_root: bool = True) -> Tuple[List, Mapping]:
|
|
142
|
-
...
|
|
140
|
+
def normalize(obj: List, check_root: bool = True) -> Tuple[List, Mapping]: ...
|
|
143
141
|
|
|
144
142
|
|
|
145
143
|
def normalize(obj: JsonLike, check_root: bool = True) -> Tuple[JsonLike, Mapping]:
|
|
@@ -169,7 +167,7 @@ def normalize(obj: JsonLike, check_root: bool = True) -> Tuple[JsonLike, Mapping
|
|
|
169
167
|
for key, value in _loop(obj):
|
|
170
168
|
# For iterables, recursively call normalize
|
|
171
169
|
if isinstance(value, (dict, list)):
|
|
172
|
-
_normalized, _lookup = normalize(value)
|
|
170
|
+
_normalized, _lookup = normalize(value) # type: ignore
|
|
173
171
|
output[key] = _normalized # type: ignore
|
|
174
172
|
lookup.update(_lookup)
|
|
175
173
|
else:
|
|
@@ -180,13 +178,11 @@ def normalize(obj: JsonLike, check_root: bool = True) -> Tuple[JsonLike, Mapping
|
|
|
180
178
|
|
|
181
179
|
|
|
182
180
|
@overload
|
|
183
|
-
def denormalize(normalized_obj: Mapping, lookup: Mapping) -> Mapping:
|
|
184
|
-
...
|
|
181
|
+
def denormalize(normalized_obj: Mapping, lookup: Mapping) -> Mapping: ...
|
|
185
182
|
|
|
186
183
|
|
|
187
184
|
@overload
|
|
188
|
-
def denormalize(normalized_obj: List, lookup: Mapping) -> List:
|
|
189
|
-
...
|
|
185
|
+
def denormalize(normalized_obj: List, lookup: Mapping) -> List: ...
|
|
190
186
|
|
|
191
187
|
|
|
192
188
|
def denormalize(normalized_obj: JsonLike, lookup: Mapping) -> Optional[JsonLike]:
|
|
@@ -206,7 +202,7 @@ def denormalize(normalized_obj: JsonLike, lookup: Mapping) -> Optional[JsonLike]
|
|
|
206
202
|
# Whole object is a placeholder
|
|
207
203
|
if _is_placeholder(normalized_obj):
|
|
208
204
|
ref = normalized_obj['__ref']
|
|
209
|
-
referrable = lookup
|
|
205
|
+
referrable = lookup.get(ref, None)
|
|
210
206
|
|
|
211
207
|
if isinstance(referrable, (list, dict)):
|
|
212
208
|
return denormalize(referrable, lookup)
|
|
@@ -15,7 +15,7 @@ See the License for the specific language governing permissions and
|
|
|
15
15
|
limitations under the License.
|
|
16
16
|
"""
|
|
17
17
|
|
|
18
|
-
from typing import Optional, TypeVar
|
|
18
|
+
from typing import Optional, TypeVar, cast
|
|
19
19
|
|
|
20
20
|
from pandas import DataFrame, MultiIndex
|
|
21
21
|
|
|
@@ -31,7 +31,7 @@ def append_index(df: Optional[DataFrame]) -> Optional[DataFrame]:
|
|
|
31
31
|
|
|
32
32
|
if INDEX not in df.columns:
|
|
33
33
|
new_df = df.copy()
|
|
34
|
-
new_df.insert(0, INDEX, range(0, len(df.index)))
|
|
34
|
+
new_df.insert(0, INDEX, range(0, len(df.index))) # type: ignore
|
|
35
35
|
return new_df
|
|
36
36
|
|
|
37
37
|
return df
|
|
@@ -47,7 +47,7 @@ def remove_index(value: value_type) -> value_type:
|
|
|
47
47
|
Otherwise return same value untouched.
|
|
48
48
|
"""
|
|
49
49
|
if isinstance(value, DataFrame):
|
|
50
|
-
return value.drop(columns=['__index__'], inplace=False, errors='ignore')
|
|
50
|
+
return cast(value_type, value.drop(columns=['__index__'], inplace=False, errors='ignore'))
|
|
51
51
|
|
|
52
52
|
return value
|
|
53
53
|
|
|
@@ -16,10 +16,11 @@ limitations under the License.
|
|
|
16
16
|
"""
|
|
17
17
|
|
|
18
18
|
import atexit
|
|
19
|
+
from collections.abc import Coroutine
|
|
19
20
|
from contextlib import contextmanager
|
|
20
21
|
from datetime import datetime
|
|
21
22
|
from multiprocessing import active_children
|
|
22
|
-
from typing import Any, Callable,
|
|
23
|
+
from typing import Any, Callable, Dict, Optional, Union, cast
|
|
23
24
|
|
|
24
25
|
import anyio
|
|
25
26
|
from anyio.abc import TaskGroup
|
|
@@ -102,16 +103,18 @@ class TaskPool:
|
|
|
102
103
|
try:
|
|
103
104
|
await wait_while(
|
|
104
105
|
lambda: self.status != PoolStatus.RUNNING
|
|
105
|
-
or
|
|
106
|
+
or len(self.workers) != self.desired_workers
|
|
106
107
|
or not all(w.status == WorkerStatus.IDLE for w in self.workers.values()),
|
|
107
108
|
timeout=timeout,
|
|
108
109
|
)
|
|
109
|
-
except TimeoutError:
|
|
110
|
-
raise RuntimeError('Failed to start pool')
|
|
110
|
+
except TimeoutError as e:
|
|
111
|
+
raise RuntimeError('Failed to start pool') from e
|
|
111
112
|
else:
|
|
112
113
|
raise RuntimeError('Pool already started')
|
|
113
114
|
|
|
114
|
-
def submit(
|
|
115
|
+
def submit(
|
|
116
|
+
self, task_uid: str, function_name: str, args: Union[tuple, None] = None, kwargs: Union[dict, None] = None
|
|
117
|
+
) -> TaskDefinition:
|
|
115
118
|
"""
|
|
116
119
|
Submit a new task to the pool
|
|
117
120
|
|
|
@@ -120,6 +123,10 @@ class TaskPool:
|
|
|
120
123
|
:param args: list of arguments to pass to the function
|
|
121
124
|
:param kwargs: dict of kwargs to pass to the function
|
|
122
125
|
"""
|
|
126
|
+
if args is None:
|
|
127
|
+
args = ()
|
|
128
|
+
if kwargs is None:
|
|
129
|
+
kwargs = {}
|
|
123
130
|
self._check_pool_state()
|
|
124
131
|
|
|
125
132
|
# Create a task definition to keep track of its progress
|
|
@@ -463,9 +470,8 @@ class TaskPool:
|
|
|
463
470
|
)
|
|
464
471
|
elif is_log(worker_msg):
|
|
465
472
|
dev_logger.info(f'Task: {worker_msg["task_uid"]}', {'logs': worker_msg['log']})
|
|
466
|
-
elif is_progress(worker_msg):
|
|
467
|
-
|
|
468
|
-
await self._progress_subscribers[worker_msg['task_uid']](worker_msg['progress'], worker_msg['message'])
|
|
473
|
+
elif is_progress(worker_msg) and worker_msg['task_uid'] in self._progress_subscribers:
|
|
474
|
+
await self._progress_subscribers[worker_msg['task_uid']](worker_msg['progress'], worker_msg['message'])
|
|
469
475
|
|
|
470
476
|
async def _wait_queue_depletion(self, timeout: Optional[float] = None):
|
|
471
477
|
"""
|
|
@@ -478,8 +484,8 @@ class TaskPool:
|
|
|
478
484
|
condition=lambda: self.status in (PoolStatus.CLOSED, PoolStatus.RUNNING) and len(self.tasks) > 0,
|
|
479
485
|
timeout=timeout,
|
|
480
486
|
)
|
|
481
|
-
except TimeoutError:
|
|
482
|
-
raise TimeoutError('Tasks are still being executed')
|
|
487
|
+
except TimeoutError as e:
|
|
488
|
+
raise TimeoutError('Tasks are still being executed') from e
|
|
483
489
|
|
|
484
490
|
async def _core_loop(self):
|
|
485
491
|
"""
|
dara/core/internal/pool/utils.py
CHANGED
|
@@ -89,9 +89,7 @@ def read_from_shared_memory(pointer: SharedMemoryPointer) -> Any:
|
|
|
89
89
|
data = shared_mem.buf[:data_size]
|
|
90
90
|
|
|
91
91
|
# Unpickle and deepcopy
|
|
92
|
-
decoded_payload_shared = pickle.loads(
|
|
93
|
-
shared_mem.buf
|
|
94
|
-
) # nosec B301 # we trust the shared memory pointer passed by the pool
|
|
92
|
+
decoded_payload_shared = pickle.loads(shared_mem.buf) # nosec B301 # we trust the shared memory pointer passed by the pool
|
|
95
93
|
decoded_payload = copy.deepcopy(decoded_payload_shared)
|
|
96
94
|
|
|
97
95
|
# Cleanup
|
|
@@ -141,8 +139,8 @@ async def stop_process_async(process: BaseProcess, timeout: float = 3):
|
|
|
141
139
|
try:
|
|
142
140
|
os.kill(process.pid, signal.SIGKILL)
|
|
143
141
|
await wait_while(process.is_alive, timeout)
|
|
144
|
-
except OSError:
|
|
145
|
-
raise RuntimeError(f'Unable to terminate subprocess with PID {process.pid}')
|
|
142
|
+
except OSError as e:
|
|
143
|
+
raise RuntimeError(f'Unable to terminate subprocess with PID {process.pid}') from e
|
|
146
144
|
|
|
147
145
|
# If it's still alive raise an exception
|
|
148
146
|
if process.is_alive():
|
|
@@ -171,8 +169,8 @@ def stop_process(process: BaseProcess, timeout: float = 3):
|
|
|
171
169
|
try:
|
|
172
170
|
os.kill(process.pid, signal.SIGKILL)
|
|
173
171
|
process.join(timeout)
|
|
174
|
-
except OSError:
|
|
175
|
-
raise RuntimeError(f'Unable to terminate subprocess with PID {process.pid}')
|
|
172
|
+
except OSError as e:
|
|
173
|
+
raise RuntimeError(f'Unable to terminate subprocess with PID {process.pid}') from e
|
|
176
174
|
|
|
177
175
|
# If it's still alive raise an exception
|
|
178
176
|
if process.is_alive():
|
|
@@ -57,7 +57,8 @@ class StdoutLogger:
|
|
|
57
57
|
self.channel.worker_api.log(self.task_uid, msg)
|
|
58
58
|
|
|
59
59
|
def flush(self):
|
|
60
|
-
sys.__stdout__
|
|
60
|
+
if sys.__stdout__:
|
|
61
|
+
sys.__stdout__.flush()
|
|
61
62
|
|
|
62
63
|
|
|
63
64
|
def execute_function(func: Callable, args: tuple, kwargs: dict):
|
|
@@ -164,7 +165,7 @@ def worker_loop(worker_params: WorkerParameters, channel: Channel):
|
|
|
164
165
|
|
|
165
166
|
# Redirect logs via the channel
|
|
166
167
|
stdout_logger = StdoutLogger(task_uid, channel)
|
|
167
|
-
sys.stdout = stdout_logger
|
|
168
|
+
sys.stdout = stdout_logger # type: ignore
|
|
168
169
|
|
|
169
170
|
try:
|
|
170
171
|
payload_pointer = task['payload']
|
dara/core/internal/port_utils.py
CHANGED
|
@@ -27,7 +27,7 @@ def is_available(host: str, port: int) -> bool:
|
|
|
27
27
|
"""
|
|
28
28
|
try:
|
|
29
29
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
|
30
|
-
sock.settimeout(2.0)
|
|
30
|
+
sock.settimeout(2.0) # timeout in case port is blocked
|
|
31
31
|
return sock.connect_ex((host, port)) != 0
|
|
32
32
|
except BaseException:
|
|
33
33
|
return False
|
dara/core/internal/registries.py
CHANGED
|
@@ -15,8 +15,9 @@ See the License for the specific language governing permissions and
|
|
|
15
15
|
limitations under the License.
|
|
16
16
|
"""
|
|
17
17
|
|
|
18
|
+
from collections.abc import Mapping
|
|
18
19
|
from datetime import datetime
|
|
19
|
-
from typing import Any, Callable,
|
|
20
|
+
from typing import Any, Callable, Set
|
|
20
21
|
|
|
21
22
|
from dara.core.auth import BaseAuthConfig
|
|
22
23
|
from dara.core.base_definitions import ActionDef, ActionResolverDef, UploadResolverDef
|
|
@@ -31,15 +32,16 @@ from dara.core.interactivity.derived_variable import (
|
|
|
31
32
|
DerivedVariableRegistryEntry,
|
|
32
33
|
LatestValueRegistryEntry,
|
|
33
34
|
)
|
|
35
|
+
from dara.core.internal.download import DownloadDataEntry
|
|
34
36
|
from dara.core.internal.registry import Registry, RegistryType
|
|
35
37
|
from dara.core.internal.websocket import CustomClientMessagePayload
|
|
36
38
|
from dara.core.persistence import BackendStoreEntry
|
|
37
39
|
|
|
38
|
-
action_def_registry = Registry[ActionDef](RegistryType.ACTION_DEF, CORE_ACTIONS)
|
|
39
|
-
action_registry = Registry[ActionResolverDef](RegistryType.ACTION)
|
|
40
|
+
action_def_registry = Registry[ActionDef](RegistryType.ACTION_DEF, CORE_ACTIONS) # all registered actions
|
|
41
|
+
action_registry = Registry[ActionResolverDef](RegistryType.ACTION) # functions for actions requiring backend calls
|
|
40
42
|
upload_resolver_registry = Registry[UploadResolverDef](
|
|
41
43
|
RegistryType.UPLOAD_RESOLVER
|
|
42
|
-
)
|
|
44
|
+
) # functions for upload resolvers requiring backend calls
|
|
43
45
|
component_registry = Registry[ComponentTypeAnnotation](RegistryType.COMPONENTS, CORE_COMPONENTS)
|
|
44
46
|
config_registry = Registry[EndpointConfiguration](RegistryType.ENDPOINT_CONFIG)
|
|
45
47
|
data_variable_registry = Registry[DataVariableRegistryEntry](RegistryType.DATA_VARIABLE, allow_duplicates=False)
|
|
@@ -69,3 +71,6 @@ custom_ws_handlers_registry = Registry[Callable[[str, CustomClientMessagePayload
|
|
|
69
71
|
|
|
70
72
|
backend_store_registry = Registry[BackendStoreEntry](RegistryType.BACKEND_STORE, allow_duplicates=False)
|
|
71
73
|
"""map of store uid -> store instance"""
|
|
74
|
+
|
|
75
|
+
download_code_registry = Registry[DownloadDataEntry](RegistryType.DOWNLOAD_CODE, allow_duplicates=False)
|
|
76
|
+
"""map of download codes -> download data entry, used only to allow overriding download code behaviour via RegistryLookup"""
|
dara/core/internal/registry.py
CHANGED
|
@@ -16,8 +16,9 @@ limitations under the License.
|
|
|
16
16
|
"""
|
|
17
17
|
|
|
18
18
|
import copy
|
|
19
|
+
from collections.abc import MutableMapping
|
|
19
20
|
from enum import Enum
|
|
20
|
-
from typing import Generic,
|
|
21
|
+
from typing import Generic, Optional, TypeVar
|
|
21
22
|
|
|
22
23
|
from dara.core.metrics import CACHE_METRICS_TRACKER, total_size
|
|
23
24
|
|
|
@@ -43,6 +44,7 @@ class RegistryType(str, Enum):
|
|
|
43
44
|
PENDING_TOKENS = 'Pending tokens'
|
|
44
45
|
CUSTOM_WS_HANDLERS = 'Custom WS handlers'
|
|
45
46
|
BACKEND_STORE = 'Backend Store'
|
|
47
|
+
DOWNLOAD_CODE = 'Download Code'
|
|
46
48
|
|
|
47
49
|
|
|
48
50
|
class Registry(Generic[T]):
|