dara-core 1.17.6__py3-none-any.whl → 1.18.0__py3-none-any.whl

This diff compares the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between versions exactly as they appear in the public registries.
Files changed (68)
  1. dara/core/__init__.py +2 -0
  2. dara/core/actions.py +1 -2
  3. dara/core/auth/basic.py +9 -9
  4. dara/core/auth/routes.py +5 -5
  5. dara/core/auth/utils.py +4 -4
  6. dara/core/base_definitions.py +15 -22
  7. dara/core/cli.py +8 -7
  8. dara/core/configuration.py +5 -2
  9. dara/core/css.py +1 -2
  10. dara/core/data_utils.py +2 -2
  11. dara/core/defaults.py +4 -7
  12. dara/core/definitions.py +6 -9
  13. dara/core/http.py +7 -3
  14. dara/core/interactivity/actions.py +28 -30
  15. dara/core/interactivity/any_data_variable.py +6 -5
  16. dara/core/interactivity/any_variable.py +4 -7
  17. dara/core/interactivity/data_variable.py +1 -1
  18. dara/core/interactivity/derived_data_variable.py +7 -6
  19. dara/core/interactivity/derived_variable.py +93 -33
  20. dara/core/interactivity/filtering.py +19 -27
  21. dara/core/interactivity/plain_variable.py +3 -2
  22. dara/core/interactivity/switch_variable.py +4 -4
  23. dara/core/internal/cache_store/base_impl.py +2 -1
  24. dara/core/internal/cache_store/cache_store.py +17 -5
  25. dara/core/internal/cache_store/keep_all.py +4 -1
  26. dara/core/internal/cache_store/lru.py +5 -1
  27. dara/core/internal/cache_store/ttl.py +4 -1
  28. dara/core/internal/cgroup.py +1 -1
  29. dara/core/internal/dependency_resolution.py +46 -10
  30. dara/core/internal/devtools.py +2 -2
  31. dara/core/internal/download.py +4 -3
  32. dara/core/internal/encoder_registry.py +7 -7
  33. dara/core/internal/execute_action.py +4 -10
  34. dara/core/internal/hashing.py +1 -3
  35. dara/core/internal/import_discovery.py +3 -4
  36. dara/core/internal/normalization.py +9 -13
  37. dara/core/internal/pandas_utils.py +3 -3
  38. dara/core/internal/pool/task_pool.py +16 -10
  39. dara/core/internal/pool/utils.py +5 -7
  40. dara/core/internal/pool/worker.py +3 -2
  41. dara/core/internal/port_utils.py +1 -1
  42. dara/core/internal/registries.py +9 -4
  43. dara/core/internal/registry.py +3 -1
  44. dara/core/internal/registry_lookup.py +7 -3
  45. dara/core/internal/routing.py +77 -44
  46. dara/core/internal/scheduler.py +13 -8
  47. dara/core/internal/settings.py +2 -2
  48. dara/core/internal/tasks.py +8 -14
  49. dara/core/internal/utils.py +11 -10
  50. dara/core/internal/websocket.py +18 -19
  51. dara/core/js_tooling/js_utils.py +23 -24
  52. dara/core/logging.py +3 -6
  53. dara/core/main.py +14 -11
  54. dara/core/metrics/cache.py +1 -1
  55. dara/core/metrics/utils.py +3 -3
  56. dara/core/persistence.py +1 -1
  57. dara/core/umd/dara.core.umd.js +146 -128
  58. dara/core/visual/components/__init__.py +2 -2
  59. dara/core/visual/components/fallback.py +3 -3
  60. dara/core/visual/css/__init__.py +30 -31
  61. dara/core/visual/dynamic_component.py +10 -11
  62. dara/core/visual/progress_updater.py +4 -3
  63. {dara_core-1.17.6.dist-info → dara_core-1.18.0.dist-info}/METADATA +10 -10
  64. dara_core-1.18.0.dist-info/RECORD +114 -0
  65. dara_core-1.17.6.dist-info/RECORD +0 -114
  66. {dara_core-1.17.6.dist-info → dara_core-1.18.0.dist-info}/LICENSE +0 -0
  67. {dara_core-1.17.6.dist-info → dara_core-1.18.0.dist-info}/WHEEL +0 -0
  68. {dara_core-1.17.6.dist-info → dara_core-1.18.0.dist-info}/entry_points.txt +0 -0
dara/core/internal/cache_store/lru.py

@@ -88,12 +88,13 @@ class LRUCache(CacheStoreImpl[LruCachePolicy]):
         self.cache.pop(key, None)
         return node.value
 
-    async def get(self, key: str, unpin: bool = False) -> Optional[Any]:
+    async def get(self, key: str, unpin: bool = False, raise_for_missing: bool = False) -> Optional[Any]:
         """
         Retrieve a value from the cache.
 
         :param key: The key of the value to retrieve.
         :param unpin: If true, the entry will be unpinned if it is pinned.
+        :param raise_for_missing: If true, an exception will be raised if the entry is not found
         :return: The value associated with the key, or None if the key is not in the cache.
         """
         async with self.lock:
@@ -103,6 +104,9 @@ class LRUCache(CacheStoreImpl[LruCachePolicy]):
                     node.pin = False
                 self._move_to_front(node)
                 return node.value
+
+            if raise_for_missing:
+                raise KeyError(f'No cache entry found for {key}')
             return None
 
     async def set(self, key: str, value: Any, pin: bool = False) -> None:
dara/core/internal/cache_store/ttl.py

@@ -52,12 +52,13 @@ class TTLCache(CacheStoreImpl[TTLCachePolicy]):
             _, key = heapq.heappop(self.expiration_heap)
             self.unpinned_cache.pop(key, None)
 
-    async def get(self, key: str, unpin: bool = False) -> Any:
+    async def get(self, key: str, unpin: bool = False, raise_for_missing: bool = False) -> Any:
         """
         Retrieve a value from the cache.
 
         :param key: The key of the value to retrieve.
         :param unpin: If true, the entry will be unpinned if it is pinned.
+        :param raise_for_missing: If true, an exception will be raised if the entry is not found
         :return: The value associated with the key, or None if the key is not in the cache.
         """
         async with self.lock:
@@ -75,6 +76,8 @@ class TTLCache(CacheStoreImpl[TTLCachePolicy]):
             elif key in self.unpinned_cache:
                 return self.unpinned_cache[key].value
 
+            if raise_for_missing:
+                raise KeyError(f'No cache entry found for {key}')
             return None
 
     async def set(self, key: str, value: Any, pin: bool = False) -> None:
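Both cache implementations gain the same opt-in miss behaviour: with raise_for_missing=True, a miss raises KeyError instead of returning None, so a cached None value can be told apart from an absent entry. A minimal usage sketch (the caller and cache instance are assumed for illustration; only the get() signature comes from this diff):

    async def read_required(cache, key: str):
        try:
            # A miss now raises instead of returning None, so a stored None
            # is distinguishable from a missing entry.
            return await cache.get(key, raise_for_missing=True)
        except KeyError:
            # Recompute the value here, or surface the error to the caller
            raise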
dara/core/internal/cgroup.py

@@ -22,7 +22,7 @@ import sys
 
 from dara.core.logging import dev_logger, eng_logger
 
-CGROUP_V2_INDICATOR_PATH = '/sys/fs/cgroup/cgroup.controllers' # Used to determine whether we're using cgroupv2
+CGROUP_V2_INDICATOR_PATH = '/sys/fs/cgroup/cgroup.controllers'  # Used to determine whether we're using cgroupv2
 
 CGROUP_V1_MEM_PATH = '/sys/fs/cgroup/memory/memory.limit_in_bytes'
 CGROUP_V2_MEM_PATH = '/sys/fs/cgroup/memory.max'
dara/core/internal/dependency_resolution.py

@@ -30,20 +30,18 @@ from dara.core.logging import dev_logger
 
 
 class ResolvedDerivedVariable(TypedDict):
-    deps: List[int]
     type: Literal['derived']
     uid: str
     values: List[Any]
-    force: bool
+    force_key: Optional[str]
 
 
 class ResolvedDerivedDataVariable(TypedDict):
-    deps: List[int]
     type: Literal['derived-data']
     uid: str
     values: List[Any]
    filters: Optional[Union[FilterQuery, dict]]
-    force: bool
+    force_key: Optional[str]
 
 
 class ResolvedDataVariable(TypedDict):
@@ -64,7 +62,9 @@ def is_resolved_derived_variable(obj: Any) -> TypeGuard[ResolvedDerivedVariable]
     return isinstance(obj, dict) and 'uid' in obj and obj.get('type') == 'derived'
 
 
-def is_resolved_derived_data_variable(obj: Any) -> TypeGuard[ResolvedDerivedDataVariable]:
+def is_resolved_derived_data_variable(
+    obj: Any,
+) -> TypeGuard[ResolvedDerivedDataVariable]:
     return isinstance(obj, dict) and 'uid' in obj and obj.get('type') == 'derived-data'
 
 
@@ -76,9 +76,29 @@ def is_resolved_switch_variable(obj: Any) -> TypeGuard[ResolvedSwitchVariable]:
     return isinstance(obj, dict) and 'uid' in obj and obj.get('type') == 'switch'
 
 
+def clean_force_key(value: Any) -> Any:
+    """
+    Clean an argument to a value to remove force keys
+    """
+    if value is None:
+        return value
+
+    if isinstance(value, dict):
+        # Remove force key from the value
+        value.pop('force_key', None)
+        return {k: clean_force_key(v) for k, v in value.items()}
+    if isinstance(value, list):
+        return [clean_force_key(v) for v in value]
+    return value
+
+
 async def resolve_dependency(
     entry: Union[
-        ResolvedDerivedDataVariable, ResolvedDataVariable, ResolvedDerivedVariable, ResolvedSwitchVariable, Any
+        ResolvedDerivedDataVariable,
+        ResolvedDataVariable,
+        ResolvedDerivedVariable,
+        ResolvedSwitchVariable,
+        Any,
     ],
     store: CacheStore,
     task_mgr: TaskManager,
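The new clean_force_key helper strips force_key entries recursively from resolved values. A behaviour sketch derived from the definition above (note that it also pops force_key from the input dict in place):

    entry = {
        'type': 'derived',
        'uid': 'dv-1',
        'force_key': 'req-123',
        'values': [{'force_key': 'nested', 'x': 1}, 2],
    }
    cleaned = clean_force_key(entry)
    # force_key is removed at every level of nesting:
    # {'type': 'derived', 'uid': 'dv-1', 'values': [{'x': 1}, 2]}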
@@ -127,13 +147,21 @@ async def _resolve_derived_data_var(entry: ResolvedDerivedDataVariable, store: C
     input_values: List[Any] = entry.get('values', [])
 
     result = await DerivedDataVariable.resolve_value(
-        data_var, dv_var, store, task_mgr, input_values, entry.get('force', False), entry.get('filters', None)
+        data_entry=data_var,
+        dv_entry=dv_var,
+        store=store,
+        task_mgr=task_mgr,
+        args=input_values,
+        filters=entry.get('filters', None),
+        force_key=entry.get('force_key'),
     )
     return remove_index(result)
 
 
 async def _resolve_derived_var(
-    derived_variable_entry: ResolvedDerivedVariable, store: CacheStore, task_mgr: TaskManager
+    derived_variable_entry: ResolvedDerivedVariable,
+    store: CacheStore,
+    task_mgr: TaskManager,
 ):
     """
     Resolve a derived variable from the registry and get it's new value based on the dynamic variable mapping passed
@@ -149,7 +177,11 @@ async def _resolve_derived_var(
     var = await registry_mgr.get(derived_variable_registry, str(derived_variable_entry.get('uid')))
     input_values: List[Any] = derived_variable_entry.get('values', [])
     result = await DerivedVariable.get_value(
-        var, store, task_mgr, input_values, derived_variable_entry.get('force', False)
+        var_entry=var,
+        store=store,
+        task_mgr=task_mgr,
+        args=input_values,
+        force_key=derived_variable_entry.get('force_key'),
     )
     return result['value']
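Taken together with the TypedDict changes above, the wire format for resolved (derived) variables drops the deps list and the force boolean in favour of a nullable force_key string. Illustrative payload shapes, inferred from the type definitions (the key values themselves are opaque):

    before = {'type': 'derived', 'uid': 'dv-1', 'values': [1, 2], 'force': True}
    after = {'type': 'derived', 'uid': 'dv-1', 'values': [1, 2], 'force_key': 'a1b2c3'}
    # A None/absent force_key presumably means "use the cache"; a fresh key
    # identifies one specific forced-recalculation request.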
 
@@ -228,7 +260,11 @@ def _evaluate_condition(condition: dict) -> bool:
     raise ValueError(f'Unknown condition operator: {operator}')
 
 
-async def _resolve_switch_var(switch_variable_entry: ResolvedSwitchVariable, store: CacheStore, task_mgr: TaskManager):
+async def _resolve_switch_var(
+    switch_variable_entry: ResolvedSwitchVariable,
+    store: CacheStore,
+    task_mgr: TaskManager,
+):
     """
     Resolve a switch variable by evaluating its constituent parts and returning the appropriate value.
 
dara/core/internal/devtools.py

@@ -36,7 +36,7 @@ def print_stacktrace():
     trc = 'Traceback (most recent call last):\n'
     stackstr = trc + ''.join(traceback.format_list(stack))
     if exc is not None:
-        stackstr += ' ' + traceback.format_exc().lstrip(trc)  # pylint:disable=bad-str-strip-call
+        stackstr += ' ' + traceback.format_exc().lstrip(trc)
     else:
         stackstr += ' Exception'
 
@@ -52,7 +52,7 @@ def handle_system_exit(error_msg: str):
     try:
         yield
     except SystemExit as e:
-        raise InterruptedError(error_msg).with_traceback(e.__traceback__)
+        raise InterruptedError(error_msg) from e
 
 
 def get_error_for_channel() -> dict:
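raise ... from e makes the original SystemExit the explicit cause (__cause__) of the new error, so both tracebacks appear chained in the report, instead of grafting the old traceback onto the new exception via .with_traceback(). A generic illustration (names are illustrative, not from the diff):

    def demo():
        try:
            raise SystemExit(1)
        except SystemExit as e:
            # The report shows the SystemExit traceback, then "The above
            # exception was the direct cause of the following exception:",
            # then the InterruptedError.
            raise InterruptedError('interrupted') from e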
dara/core/internal/download.py

@@ -18,7 +18,8 @@ limitations under the License.
 from __future__ import annotations
 
 import os
-from typing import Awaitable, Callable, Optional, Tuple
+from collections.abc import Awaitable
+from typing import Callable, Optional, Tuple
 from uuid import uuid4
 
 import anyio
@@ -37,13 +38,13 @@ class DownloadDataEntry(BaseModel):
     file_path: str
     cleanup_file: bool
     identity_name: Optional[str] = None
-    download: Callable[['DownloadDataEntry'], Awaitable[Tuple[anyio.AsyncFile, Callable[..., Awaitable]]]]
+    download: Callable[[DownloadDataEntry], Awaitable[Tuple[anyio.AsyncFile, Callable[..., Awaitable]]]]
     """Handler for getting the file from the entry"""
 
 
 DownloadRegistryEntry = CachedRegistryEntry(
     uid='_dara_download', cache=Cache.Policy.TTL(ttl=60 * 10)
-) # expire the codes after 10 minutes
+)  # expire the codes after 10 minutes
 
 
 async def download(data_entry: DownloadDataEntry) -> Tuple[anyio.AsyncFile, Callable[..., Awaitable]]:
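The quotes around the DownloadDataEntry self-reference can be dropped because this module already has from __future__ import annotations (PEP 563): annotations are stored as strings and never evaluated at class-definition time, so forward references need no quoting. A generic sketch of why this is legal:

    from __future__ import annotations

    class Node:
        # Fine without quotes under postponed evaluation, even though
        # Node is not fully defined at this point in the module.
        next: Node | None = None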
dara/core/internal/encoder_registry.py

@@ -14,13 +14,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 """
-# pylint: disable=unnecessary-lambda
+
+from collections.abc import MutableMapping
 from inspect import Parameter, isclass
 from typing import (
     Any,
     Callable,
     Dict,
-    MutableMapping,
     Optional,
     Type,
     Union,
@@ -99,11 +99,11 @@ def _tuple_key_deserialize(d):
         if isinstance(key, str) and key.startswith('__tuple__'):
             key_list = []
             for each in key[10:-1].split(', '):
-                if (each.startswith("'") and each.endswith(("'"))) or (each.startswith('"') and each.endswith(('"'))):
+                if (each.startswith("'") and each.endswith("'")) or (each.startswith('"') and each.endswith('"')):
                     key_list.append(each[1:-1])
                 else:
                     key_list.append(each)
-            encoded_key = encoded_key = tuple(key_list)
+            encoded_key = tuple(key_list)
         else:
             encoded_key = key
 
@@ -112,7 +112,7 @@ def _tuple_key_deserialize(d):
 
     return encoded_dict
 
 
-def _df_deserialize(x):  # pylint: disable=inconsistent-return-statements
+def _df_deserialize(x):
     """
     A function to deserialize data into a DataFrame
 
@@ -240,14 +240,14 @@ def deserialize(value: Any, typ: Optional[Type]):
         return value
 
     # Already matches type
-    if type(value) == typ:
+    if type(value) is typ:
         return value
 
     # Handle Optional[foo] / Union[foo, None] -> call deserialize(value, foo)
     if get_origin(typ) == Union:
         args = get_args(typ)
         if len(args) == 2 and type(None) in args:
-            not_none_arg = args[0] if args[0] != type(None) else args[1]
+            not_none_arg = args[0] if args[0] is not type(None) else args[1]
             return deserialize(value, not_none_arg)
 
     try:
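Identity checks (is / is not) are the precise comparison for concrete types and NoneType. The Optional branch relies on Optional[X] being shorthand for Union[X, None]; a self-contained check of that unwrapping step:

    from typing import Optional, Union, get_args, get_origin

    typ = Optional[int]                   # same as Union[int, None]
    assert get_origin(typ) is Union
    args = get_args(typ)                  # (int, NoneType)
    not_none_arg = args[0] if args[0] is not type(None) else args[1]
    assert not_none_arg is int            # deserialize() recurses with int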
dara/core/internal/execute_action.py

@@ -14,11 +14,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+
 from __future__ import annotations
 
 import asyncio
+from collections.abc import Mapping
 from contextvars import ContextVar
-from typing import Any, Callable, Mapping, Optional, Union
+from typing import Any, Callable, Optional, Union
 
 import anyio
 
@@ -30,11 +32,7 @@ from dara.core.interactivity.actions import (
     ActionImpl,
 )
 from dara.core.internal.cache_store import CacheStore
-from dara.core.internal.dependency_resolution import (
-    is_resolved_derived_data_variable,
-    is_resolved_derived_variable,
-    resolve_dependency,
-)
+from dara.core.internal.dependency_resolution import resolve_dependency
 from dara.core.internal.encoder_registry import deserialize
 from dara.core.internal.tasks import MetaTask, TaskManager
 from dara.core.internal.utils import run_user_handler
@@ -146,10 +144,6 @@ async def execute_action(
     annotations = action.__annotations__
 
     for key, value in values.items():
-        # Override `force` property to be false
-        if is_resolved_derived_variable(value) or is_resolved_derived_data_variable(value):
-            value['force'] = False
-
         typ = annotations.get(key)
         val = await resolve_dependency(value, store, task_mgr)
         resolved_kwargs[key] = deserialize(val, typ)
dara/core/internal/hashing.py

@@ -31,8 +31,6 @@ def hash_object(obj: Union[BaseModel, dict, None]):
     if isinstance(obj, BaseModel):
         obj = obj.model_dump()
 
-    filter_hash = hashlib.sha1(
-        usedforsecurity=False
-    )  # nosec B303 # we don't use this for security purposes just as a cache key
+    filter_hash = hashlib.sha1(usedforsecurity=False)  # nosec B303 # we don't use this for security purposes just as a cache key
     filter_hash.update(json.dumps(obj or {}, sort_keys=True).encode())
     return filter_hash.hexdigest()
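The reflow is behaviour-preserving: sort_keys keeps the digest independent of dict key order, and usedforsecurity=False (available since Python 3.9) marks the SHA-1 as a cache key rather than a security primitive. A standalone equivalent:

    import hashlib
    import json

    h1 = hashlib.sha1(json.dumps({'b': 2, 'a': 1}, sort_keys=True).encode(), usedforsecurity=False)
    h2 = hashlib.sha1(json.dumps({'a': 1, 'b': 2}, sort_keys=True).encode(), usedforsecurity=False)
    assert h1.hexdigest() == h2.hexdigest()  # key order does not affect the cache key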
dara/core/internal/import_discovery.py

@@ -93,10 +93,9 @@ def run_discovery(
     # If module root is passed through, use it
     if 'module_root' in kwargs:
         root = kwargs.get('module_root')
-    else:
-        # Try to infer from module_name
-        if module_name is not None:
-            root = module_name.split('.')[0]
+    # Try to infer from module_name
+    elif module_name is not None:
+        root = module_name.split('.')[0]
 
     for k, v in global_symbols.items():
         # Ignore already encountered functions
dara/core/internal/normalization.py

@@ -15,11 +15,11 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
+from collections.abc import Mapping
 from typing import (
     Any,
     Generic,
     List,
-    Mapping,
     Optional,
     Tuple,
     TypeVar,
@@ -48,7 +48,7 @@ class Placeholder(TypedDict):
     Placeholder object 'Referrable' objects are replaced with
     """
 
-    __ref: str  # pylint: disable=unused-private-member
+    __ref: str
 
 
 class Referrable(TypedDict):
@@ -56,7 +56,7 @@ class Referrable(TypedDict):
     Describes an object which can be replaced by a Placeholder.
     """
 
-    __typename: str  # pylint: disable=unused-private-member
+    __typename: str
     uid: str
 
 
@@ -133,13 +133,11 @@ def _loop(iterable: JsonLike):
 
 
 @overload
-def normalize(obj: Mapping, check_root: bool = True) -> Tuple[Mapping, Mapping]:
-    ...
+def normalize(obj: Mapping, check_root: bool = True) -> Tuple[Mapping, Mapping]: ...
 
 
 @overload
-def normalize(obj: List, check_root: bool = True) -> Tuple[List, Mapping]:
-    ...
+def normalize(obj: List, check_root: bool = True) -> Tuple[List, Mapping]: ...
 
 
 def normalize(obj: JsonLike, check_root: bool = True) -> Tuple[JsonLike, Mapping]:
@@ -169,7 +167,7 @@ def normalize(obj: JsonLike, check_root: bool = True) -> Tuple[JsonLike, Mapping
     for key, value in _loop(obj):
         # For iterables, recursively call normalize
         if isinstance(value, (dict, list)):
-            _normalized, _lookup = normalize(value) # type: ignore
+            _normalized, _lookup = normalize(value)  # type: ignore
             output[key] = _normalized  # type: ignore
             lookup.update(_lookup)
         else:
@@ -180,13 +178,11 @@ def normalize(obj: JsonLike, check_root: bool = True) -> Tuple[JsonLike, Mapping
 
 
 @overload
-def denormalize(normalized_obj: Mapping, lookup: Mapping) -> Mapping:
-    ...
+def denormalize(normalized_obj: Mapping, lookup: Mapping) -> Mapping: ...
 
 
 @overload
-def denormalize(normalized_obj: List, lookup: Mapping) -> List:
-    ...
+def denormalize(normalized_obj: List, lookup: Mapping) -> List: ...
 
 
 def denormalize(normalized_obj: JsonLike, lookup: Mapping) -> Optional[JsonLike]:
@@ -206,7 +202,7 @@ def denormalize(normalized_obj: JsonLike, lookup: Mapping) -> Optional[JsonLike]
     # Whole object is a placeholder
     if _is_placeholder(normalized_obj):
         ref = normalized_obj['__ref']
-        referrable = lookup[ref] if ref in lookup else None
+        referrable = lookup.get(ref, None)
 
         if isinstance(referrable, (list, dict)):
             return denormalize(referrable, lookup)
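lookup.get(ref, None) is a drop-in replacement for the conditional lookup. For orientation, a hypothetical round-trip of the normalize/denormalize pair (the placeholder format is internal; only the round-trip property is assumed here):

    payload = {'variable': {'__typename': 'Variable', 'uid': 'v1', 'default': 0}}
    normalized, lookup = normalize(payload)   # Referrable dicts become {'__ref': ...}
    assert denormalize(normalized, lookup) == payload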
dara/core/internal/pandas_utils.py

@@ -15,7 +15,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-from typing import Optional, TypeVar
+from typing import Optional, TypeVar, cast
 
 from pandas import DataFrame, MultiIndex
 
@@ -31,7 +31,7 @@ def append_index(df: Optional[DataFrame]) -> Optional[DataFrame]:
 
     if INDEX not in df.columns:
         new_df = df.copy()
-        new_df.insert(0, INDEX, range(0, len(df.index)))
+        new_df.insert(0, INDEX, range(0, len(df.index)))  # type: ignore
         return new_df
 
     return df
@@ -47,7 +47,7 @@ def remove_index(value: value_type) -> value_type:
     Otherwise return same value untouched.
     """
     if isinstance(value, DataFrame):
-        return value.drop(columns=['__index__'], inplace=False, errors='ignore')
+        return cast(value_type, value.drop(columns=['__index__'], inplace=False, errors='ignore'))
 
     return value
 
dara/core/internal/pool/task_pool.py

@@ -16,10 +16,11 @@ limitations under the License.
 """
 
 import atexit
+from collections.abc import Coroutine
 from contextlib import contextmanager
 from datetime import datetime
 from multiprocessing import active_children
-from typing import Any, Callable, Coroutine, Dict, Optional, cast
+from typing import Any, Callable, Dict, Optional, Union, cast
 
 import anyio
 from anyio.abc import TaskGroup
@@ -102,16 +103,18 @@ class TaskPool:
             try:
                 await wait_while(
                     lambda: self.status != PoolStatus.RUNNING
-                    or not len(self.workers) == self.desired_workers
+                    or len(self.workers) != self.desired_workers
                    or not all(w.status == WorkerStatus.IDLE for w in self.workers.values()),
                     timeout=timeout,
                 )
-            except TimeoutError:
-                raise RuntimeError('Failed to start pool')
+            except TimeoutError as e:
+                raise RuntimeError('Failed to start pool') from e
         else:
            raise RuntimeError('Pool already started')
 
-    def submit(self, task_uid: str, function_name: str, args: tuple = (), kwargs: dict = {}) -> TaskDefinition:
+    def submit(
+        self, task_uid: str, function_name: str, args: Union[tuple, None] = None, kwargs: Union[dict, None] = None
+    ) -> TaskDefinition:
        """
         Submit a new task to the pool
 
@@ -120,6 +123,10 @@ class TaskPool:
         :param args: list of arguments to pass to the function
         :param kwargs: dict of kwargs to pass to the function
         """
+        if args is None:
+            args = ()
+        if kwargs is None:
+            kwargs = {}
         self._check_pool_state()
 
         # Create a task definition to keep track of its progress
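The signature change removes Python's shared-mutable-default pitfall: a dict or tuple default is created once at function-definition time and reused across every call. The None-sentinel pattern adopted above is the standard fix; a minimal demonstration of the bug it prevents:

    from typing import Union

    def bad(kwargs: dict = {}):        # one dict object shared by every call
        kwargs.setdefault('n', 0)
        kwargs['n'] += 1
        return kwargs['n']

    print(bad(), bad())                # 1 2 -- state leaks between calls

    def good(kwargs: Union[dict, None] = None):  # the pattern adopted above
        kwargs = {} if kwargs is None else kwargs
        return kwargs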
@@ -463,9 +470,8 @@ class TaskPool:
             )
         elif is_log(worker_msg):
             dev_logger.info(f'Task: {worker_msg["task_uid"]}', {'logs': worker_msg['log']})
-        elif is_progress(worker_msg):
-            if worker_msg['task_uid'] in self._progress_subscribers:
-                await self._progress_subscribers[worker_msg['task_uid']](worker_msg['progress'], worker_msg['message'])
+        elif is_progress(worker_msg) and worker_msg['task_uid'] in self._progress_subscribers:
+            await self._progress_subscribers[worker_msg['task_uid']](worker_msg['progress'], worker_msg['message'])
 
     async def _wait_queue_depletion(self, timeout: Optional[float] = None):
         """
@@ -478,8 +484,8 @@ class TaskPool:
                 condition=lambda: self.status in (PoolStatus.CLOSED, PoolStatus.RUNNING) and len(self.tasks) > 0,
                 timeout=timeout,
             )
-        except TimeoutError:
-            raise TimeoutError('Tasks are still being executed')
+        except TimeoutError as e:
+            raise TimeoutError('Tasks are still being executed') from e
 
     async def _core_loop(self):
         """
dara/core/internal/pool/utils.py

@@ -89,9 +89,7 @@ def read_from_shared_memory(pointer: SharedMemoryPointer) -> Any:
     data = shared_mem.buf[:data_size]
 
     # Unpickle and deepcopy
-    decoded_payload_shared = pickle.loads(
-        shared_mem.buf
-    )  # nosec B301 # we trust the shared memory pointer passed by the pool
+    decoded_payload_shared = pickle.loads(shared_mem.buf)  # nosec B301 # we trust the shared memory pointer passed by the pool
     decoded_payload = copy.deepcopy(decoded_payload_shared)
 
     # Cleanup
@@ -141,8 +139,8 @@ async def stop_process_async(process: BaseProcess, timeout: float = 3):
     try:
         os.kill(process.pid, signal.SIGKILL)
         await wait_while(process.is_alive, timeout)
-    except OSError:
-        raise RuntimeError(f'Unable to terminate subprocess with PID {process.pid}')
+    except OSError as e:
+        raise RuntimeError(f'Unable to terminate subprocess with PID {process.pid}') from e
 
     # If it's still alive raise an exception
     if process.is_alive():
@@ -171,8 +169,8 @@ def stop_process(process: BaseProcess, timeout: float = 3):
     try:
         os.kill(process.pid, signal.SIGKILL)
         process.join(timeout)
-    except OSError:
-        raise RuntimeError(f'Unable to terminate subprocess with PID {process.pid}')
+    except OSError as e:
+        raise RuntimeError(f'Unable to terminate subprocess with PID {process.pid}') from e
 
     # If it's still alive raise an exception
     if process.is_alive():
dara/core/internal/pool/worker.py

@@ -57,7 +57,8 @@ class StdoutLogger:
         self.channel.worker_api.log(self.task_uid, msg)
 
     def flush(self):
-        sys.__stdout__.flush()
+        if sys.__stdout__:
+            sys.__stdout__.flush()
 
 
 def execute_function(func: Callable, args: tuple, kwargs: dict):
@@ -164,7 +165,7 @@ def worker_loop(worker_params: WorkerParameters, channel: Channel):
 
     # Redirect logs via the channel
     stdout_logger = StdoutLogger(task_uid, channel)
-    sys.stdout = stdout_logger # type: ignore
+    sys.stdout = stdout_logger  # type: ignore
 
     try:
         payload_pointer = task['payload']
dara/core/internal/port_utils.py

@@ -27,7 +27,7 @@ def is_available(host: str, port: int) -> bool:
     """
     try:
         with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
-            sock.settimeout(2.0) # timeout in case port is blocked
+            sock.settimeout(2.0)  # timeout in case port is blocked
             return sock.connect_ex((host, port)) != 0
     except BaseException:
         return False
dara/core/internal/registries.py

@@ -15,8 +15,9 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
+from collections.abc import Mapping
 from datetime import datetime
-from typing import Any, Callable, Mapping, Set
+from typing import Any, Callable, Set
 
 from dara.core.auth import BaseAuthConfig
 from dara.core.base_definitions import ActionDef, ActionResolverDef, UploadResolverDef
@@ -31,15 +32,16 @@ from dara.core.interactivity.derived_variable import (
     DerivedVariableRegistryEntry,
     LatestValueRegistryEntry,
 )
+from dara.core.internal.download import DownloadDataEntry
 from dara.core.internal.registry import Registry, RegistryType
 from dara.core.internal.websocket import CustomClientMessagePayload
 from dara.core.persistence import BackendStoreEntry
 
-action_def_registry = Registry[ActionDef](RegistryType.ACTION_DEF, CORE_ACTIONS) # all registered actions
-action_registry = Registry[ActionResolverDef](RegistryType.ACTION) # functions for actions requiring backend calls
+action_def_registry = Registry[ActionDef](RegistryType.ACTION_DEF, CORE_ACTIONS)  # all registered actions
+action_registry = Registry[ActionResolverDef](RegistryType.ACTION)  # functions for actions requiring backend calls
 upload_resolver_registry = Registry[UploadResolverDef](
     RegistryType.UPLOAD_RESOLVER
-) # functions for upload resolvers requiring backend calls
+)  # functions for upload resolvers requiring backend calls
 component_registry = Registry[ComponentTypeAnnotation](RegistryType.COMPONENTS, CORE_COMPONENTS)
 config_registry = Registry[EndpointConfiguration](RegistryType.ENDPOINT_CONFIG)
 data_variable_registry = Registry[DataVariableRegistryEntry](RegistryType.DATA_VARIABLE, allow_duplicates=False)
@@ -69,3 +71,6 @@ custom_ws_handlers_registry = Registry[Callable[[str, CustomClientMessagePayload
 
 backend_store_registry = Registry[BackendStoreEntry](RegistryType.BACKEND_STORE, allow_duplicates=False)
 """map of store uid -> store instance"""
+
+download_code_registry = Registry[DownloadDataEntry](RegistryType.DOWNLOAD_CODE, allow_duplicates=False)
+"""map of download codes -> download data entry, used only to allow overriding download code behaviour via RegistryLookup"""
dara/core/internal/registry.py

@@ -16,8 +16,9 @@ limitations under the License.
 """
 
 import copy
+from collections.abc import MutableMapping
 from enum import Enum
-from typing import Generic, MutableMapping, Optional, TypeVar
+from typing import Generic, Optional, TypeVar
 
 from dara.core.metrics import CACHE_METRICS_TRACKER, total_size
 
@@ -43,6 +44,7 @@ class RegistryType(str, Enum):
     PENDING_TOKENS = 'Pending tokens'
     CUSTOM_WS_HANDLERS = 'Custom WS handlers'
     BACKEND_STORE = 'Backend Store'
+    DOWNLOAD_CODE = 'Download Code'
 
 
 class Registry(Generic[T]):