dara-core 1.20.1a1__py3-none-any.whl → 1.20.1a3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (82)
  1. dara/core/__init__.py +3 -0
  2. dara/core/actions.py +1 -2
  3. dara/core/auth/basic.py +22 -16
  4. dara/core/auth/definitions.py +2 -2
  5. dara/core/auth/routes.py +5 -5
  6. dara/core/auth/utils.py +5 -5
  7. dara/core/base_definitions.py +22 -64
  8. dara/core/cli.py +8 -7
  9. dara/core/configuration.py +5 -2
  10. dara/core/css.py +1 -2
  11. dara/core/data_utils.py +18 -19
  12. dara/core/defaults.py +6 -7
  13. dara/core/definitions.py +50 -19
  14. dara/core/http.py +7 -3
  15. dara/core/interactivity/__init__.py +6 -0
  16. dara/core/interactivity/actions.py +52 -50
  17. dara/core/interactivity/any_data_variable.py +7 -134
  18. dara/core/interactivity/any_variable.py +5 -8
  19. dara/core/interactivity/client_variable.py +71 -0
  20. dara/core/interactivity/data_variable.py +8 -266
  21. dara/core/interactivity/derived_data_variable.py +7 -290
  22. dara/core/interactivity/derived_variable.py +416 -176
  23. dara/core/interactivity/filtering.py +46 -27
  24. dara/core/interactivity/loop_variable.py +2 -2
  25. dara/core/interactivity/non_data_variable.py +5 -68
  26. dara/core/interactivity/plain_variable.py +89 -15
  27. dara/core/interactivity/server_variable.py +325 -0
  28. dara/core/interactivity/state_variable.py +69 -0
  29. dara/core/interactivity/switch_variable.py +19 -19
  30. dara/core/interactivity/tabular_variable.py +94 -0
  31. dara/core/interactivity/url_variable.py +10 -90
  32. dara/core/internal/cache_store/base_impl.py +2 -1
  33. dara/core/internal/cache_store/cache_store.py +22 -25
  34. dara/core/internal/cache_store/keep_all.py +4 -1
  35. dara/core/internal/cache_store/lru.py +5 -1
  36. dara/core/internal/cache_store/ttl.py +4 -1
  37. dara/core/internal/cgroup.py +1 -1
  38. dara/core/internal/dependency_resolution.py +60 -66
  39. dara/core/internal/devtools.py +12 -5
  40. dara/core/internal/download.py +13 -4
  41. dara/core/internal/encoder_registry.py +7 -7
  42. dara/core/internal/execute_action.py +13 -13
  43. dara/core/internal/hashing.py +1 -3
  44. dara/core/internal/import_discovery.py +3 -4
  45. dara/core/internal/multi_resource_lock.py +70 -0
  46. dara/core/internal/normalization.py +9 -18
  47. dara/core/internal/pandas_utils.py +107 -5
  48. dara/core/internal/pool/definitions.py +1 -1
  49. dara/core/internal/pool/task_pool.py +25 -16
  50. dara/core/internal/pool/utils.py +21 -18
  51. dara/core/internal/pool/worker.py +3 -2
  52. dara/core/internal/port_utils.py +1 -1
  53. dara/core/internal/registries.py +12 -6
  54. dara/core/internal/registry.py +4 -2
  55. dara/core/internal/registry_lookup.py +11 -5
  56. dara/core/internal/routing.py +109 -145
  57. dara/core/internal/scheduler.py +13 -8
  58. dara/core/internal/settings.py +2 -2
  59. dara/core/internal/store.py +2 -29
  60. dara/core/internal/tasks.py +379 -195
  61. dara/core/internal/utils.py +36 -13
  62. dara/core/internal/websocket.py +21 -20
  63. dara/core/js_tooling/js_utils.py +28 -26
  64. dara/core/js_tooling/templates/vite.config.template.ts +12 -3
  65. dara/core/logging.py +13 -12
  66. dara/core/main.py +14 -11
  67. dara/core/metrics/cache.py +1 -1
  68. dara/core/metrics/utils.py +3 -3
  69. dara/core/persistence.py +27 -5
  70. dara/core/umd/dara.core.umd.js +68291 -64718
  71. dara/core/visual/components/__init__.py +2 -2
  72. dara/core/visual/components/fallback.py +30 -4
  73. dara/core/visual/components/for_cmp.py +4 -1
  74. dara/core/visual/css/__init__.py +30 -31
  75. dara/core/visual/dynamic_component.py +31 -28
  76. dara/core/visual/progress_updater.py +4 -3
  77. {dara_core-1.20.1a1.dist-info → dara_core-1.20.1a3.dist-info}/METADATA +12 -11
  78. dara_core-1.20.1a3.dist-info/RECORD +119 -0
  79. dara_core-1.20.1a1.dist-info/RECORD +0 -114
  80. {dara_core-1.20.1a1.dist-info → dara_core-1.20.1a3.dist-info}/LICENSE +0 -0
  81. {dara_core-1.20.1a1.dist-info → dara_core-1.20.1a3.dist-info}/WHEEL +0 -0
  82. {dara_core-1.20.1a1.dist-info → dara_core-1.20.1a3.dist-info}/entry_points.txt +0 -0
@@ -14,13 +14,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
14
  See the License for the specific language governing permissions and
15
15
  limitations under the License.
16
16
  """
17
- # pylint: disable=unnecessary-lambda
17
+
18
+ from collections.abc import MutableMapping
18
19
  from inspect import Parameter, isclass
19
20
  from typing import (
20
21
  Any,
21
22
  Callable,
22
23
  Dict,
23
- MutableMapping,
24
24
  Optional,
25
25
  Type,
26
26
  Union,
@@ -99,11 +99,11 @@ def _tuple_key_deserialize(d):
99
99
  if isinstance(key, str) and key.startswith('__tuple__'):
100
100
  key_list = []
101
101
  for each in key[10:-1].split(', '):
102
- if (each.startswith("'") and each.endswith(("'"))) or (each.startswith('"') and each.endswith(('"'))):
102
+ if (each.startswith("'") and each.endswith("'")) or (each.startswith('"') and each.endswith('"')):
103
103
  key_list.append(each[1:-1])
104
104
  else:
105
105
  key_list.append(each)
106
- encoded_key = encoded_key = tuple(key_list)
106
+ encoded_key = tuple(key_list)
107
107
  else:
108
108
  encoded_key = key
109
109
 
@@ -112,7 +112,7 @@ def _tuple_key_deserialize(d):
112
112
  return encoded_dict
113
113
 
114
114
 
115
- def _df_deserialize(x): # pylint: disable=inconsistent-return-statements
115
+ def _df_deserialize(x):
116
116
  """
117
117
  A function to deserialize data into a DataFrame
118
118
 
@@ -240,14 +240,14 @@ def deserialize(value: Any, typ: Optional[Type]):
240
240
  return value
241
241
 
242
242
  # Already matches type
243
- if type(value) == typ:
243
+ if type(value) is typ:
244
244
  return value
245
245
 
246
246
  # Handle Optional[foo] / Union[foo, None] -> call deserialize(value, foo)
247
247
  if get_origin(typ) == Union:
248
248
  args = get_args(typ)
249
249
  if len(args) == 2 and type(None) in args:
250
- not_none_arg = args[0] if args[0] != type(None) else args[1]
250
+ not_none_arg = args[0] if args[0] is not type(None) else args[1]
251
251
  return deserialize(value, not_none_arg)
252
252
 
253
253
  try:
@@ -14,11 +14,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
14
  See the License for the specific language governing permissions and
15
15
  limitations under the License.
16
16
  """
17
+
17
18
  from __future__ import annotations
18
19
 
19
20
  import asyncio
21
+ from collections.abc import Mapping
20
22
  from contextvars import ContextVar
21
- from typing import Any, Callable, Mapping, Optional, Union
23
+ from typing import Any, Callable, Optional, Union
22
24
 
23
25
  import anyio
24
26
 
@@ -30,11 +32,7 @@ from dara.core.interactivity.actions import (
30
32
  ActionImpl,
31
33
  )
32
34
  from dara.core.internal.cache_store import CacheStore
33
- from dara.core.internal.dependency_resolution import (
34
- is_resolved_derived_data_variable,
35
- is_resolved_derived_variable,
36
- resolve_dependency,
37
- )
35
+ from dara.core.internal.dependency_resolution import resolve_dependency
38
36
  from dara.core.internal.encoder_registry import deserialize
39
37
  from dara.core.internal.tasks import MetaTask, TaskManager
40
38
  from dara.core.internal.utils import run_user_handler
@@ -145,15 +143,15 @@ async def execute_action(
145
143
  if values is not None:
146
144
  annotations = action.__annotations__
147
145
 
148
- for key, value in values.items():
149
- # Override `force` property to be false
150
- if is_resolved_derived_variable(value) or is_resolved_derived_data_variable(value):
151
- value['force'] = False
152
-
146
+ async def _resolve_kwarg(val: Any, key: str):
153
147
  typ = annotations.get(key)
154
- val = await resolve_dependency(value, store, task_mgr)
148
+ val = await resolve_dependency(val, store, task_mgr)
155
149
  resolved_kwargs[key] = deserialize(val, typ)
156
150
 
151
+ async with anyio.create_task_group() as tg:
152
+ for key, value in values.items():
153
+ tg.start_soon(_resolve_kwarg, value, key)
154
+
157
155
  # Merge resolved dynamic kwargs with static kwargs received
158
156
  resolved_kwargs = {**resolved_kwargs, **static_kwargs}
159
157
 
@@ -177,9 +175,11 @@ async def execute_action(
177
175
 
178
176
  # Note: no associated registry entry, the result are not persisted in cache
179
177
  # Return a metatask which, when all dependencies are ready, will stream the action results to the frontend
180
- return MetaTask(
178
+ meta_task = MetaTask(
181
179
  process_result=_stream_action, args=[action, ctx], kwargs=resolved_kwargs, notify_channels=notify_channels
182
180
  )
181
+ task_mgr.register_task(meta_task)
182
+ return meta_task
183
183
 
184
184
  # No tasks - run directly as an asyncio task and return the execution id
185
185
  # Originally used to use FastAPI BackgroundTasks, but these ended up causing a blocking behavior that blocked some
@@ -31,8 +31,6 @@ def hash_object(obj: Union[BaseModel, dict, None]):
31
31
  if isinstance(obj, BaseModel):
32
32
  obj = obj.model_dump()
33
33
 
34
- filter_hash = hashlib.sha1(
35
- usedforsecurity=False
36
- ) # nosec B303 # we don't use this for security purposes just as a cache key
34
+ filter_hash = hashlib.sha1(usedforsecurity=False) # nosec B303 # we don't use this for security purposes just as a cache key
37
35
  filter_hash.update(json.dumps(obj or {}, sort_keys=True).encode())
38
36
  return filter_hash.hexdigest()
@@ -93,10 +93,9 @@ def run_discovery(
93
93
  # If module root is passed through, use it
94
94
  if 'module_root' in kwargs:
95
95
  root = kwargs.get('module_root')
96
- else:
97
- # Try to infer from module_name
98
- if module_name is not None:
99
- root = module_name.split('.')[0]
96
+ # Try to infer from module_name
97
+ elif module_name is not None:
98
+ root = module_name.split('.')[0]
100
99
 
101
100
  for k, v in global_symbols.items():
102
101
  # Ignore already encountered functions
@@ -0,0 +1,70 @@
1
+ from collections import Counter
2
+ from contextlib import asynccontextmanager
3
+
4
+ import anyio
5
+
6
+
7
class MultiResourceLock:
    """
    A class that manages multiple named locks for concurrent access to shared resources.

    Each resource name gets its own ``anyio.Lock``, so only one task at a time can hold a given
    resource while different resources can be held concurrently. Locks are created lazily on the
    first ``acquire`` of a resource and removed again once no task is holding or waiting for it.

    The locks are not reentrant: a task that acquires a resource it already holds is rejected by
    the underlying ``anyio.Lock`` (which raises ``RuntimeError`` in that case).
    """

    def __init__(self):
        # resource name -> lock guarding that resource
        self._locks: dict[str, anyio.Lock] = {}
        # resource name -> number of tasks currently holding or waiting for that resource's lock
        self._waiters = Counter[str]()
        # serializes the bookkeeping on _locks/_waiters
        self._cleanup_lock = anyio.Lock()

    def is_locked(self, resource_name: str) -> bool:
        """
        Check if a lock for the specified resource is currently held.

        :param resource_name (str): The name of the resource to check.
        :return: True if the lock is held, False otherwise.
        """
        lock = self._locks.get(resource_name)
        return lock is not None and lock.locked()

    @asynccontextmanager
    async def acquire(self, resource_name: str):
        """
        Acquire a lock for the specified resource.

        This method is an async context manager that acquires a lock for the given
        resource name. If the lock doesn't exist, it creates one. It also keeps
        track of waiters so the lock is discarded once the resource is no longer in use.

        :param resource_name (str): The name of the resource to lock.

        Usage:
        ```python
        async with multi_lock.acquire("resource_a"):
            # Critical section for "resource_a"
            ...
        ```

        Note:
            The lock is automatically released when exiting the context manager,
            including when the caller is cancelled.
        """
        async with self._cleanup_lock:
            lock = self._locks.get(resource_name)
            if lock is None:
                lock = self._locks[resource_name] = anyio.Lock()
            self._waiters[resource_name] += 1

        try:
            async with lock:
                yield
        finally:
            # Shield the bookkeeping from cancellation. If the caller was cancelled (e.g. while
            # waiting for the resource), awaiting _cleanup_lock unshielded would raise again,
            # the waiter count would never be decremented, and the entry would leak forever.
            with anyio.CancelScope(shield=True):
                async with self._cleanup_lock:
                    self._waiters[resource_name] -= 1
                    if self._waiters[resource_name] <= 0:
                        del self._waiters[resource_name]
                        del self._locks[resource_name]
@@ -15,11 +15,11 @@ See the License for the specific language governing permissions and
15
15
  limitations under the License.
16
16
  """
17
17
 
18
+ from collections.abc import Mapping
18
19
  from typing import (
19
20
  Any,
20
21
  Generic,
21
22
  List,
22
- Mapping,
23
23
  Optional,
24
24
  Tuple,
25
25
  TypeVar,
@@ -31,7 +31,6 @@ from typing import (
31
31
  from typing_extensions import TypedDict, TypeGuard
32
32
 
33
33
  from dara.core.base_definitions import DaraBaseModel as BaseModel
34
- from dara.core.internal.hashing import hash_object
35
34
 
36
35
  # Type alias for a JSON-like container: a mapping (JSON object) or a list (JSON array)
  JsonLike = Union[Mapping, List]
37
36
 
@@ -48,7 +47,7 @@ class Placeholder(TypedDict):
48
47
  Placeholder object 'Referrable' objects are replaced with
49
48
  """
50
49
 
51
- __ref: str # pylint: disable=unused-private-member
50
+ __ref: str
52
51
 
53
52
 
54
53
  class Referrable(TypedDict):
@@ -56,7 +55,7 @@ class Referrable(TypedDict):
56
55
  Describes an object which can be replaced by a Placeholder.
57
56
  """
58
57
 
59
- __typename: str # pylint: disable=unused-private-member
58
+ __typename: str
60
59
  uid: str
61
60
 
62
61
 
@@ -81,10 +80,6 @@ def _get_identifier(obj: Referrable) -> str:
81
80
  nested = ','.join(cast(List[str], obj['nested']))
82
81
  identifier = f'{identifier}:{nested}'
83
82
 
84
- if _is_referrable_with_filters(obj):
85
- filter_hash = hash_object(obj['filters'])
86
- identifier = f'{identifier}:{filter_hash}'
87
-
88
83
  return identifier
89
84
 
90
85
 
@@ -133,13 +128,11 @@ def _loop(iterable: JsonLike):
133
128
 
134
129
 
135
130
  @overload
136
- def normalize(obj: Mapping, check_root: bool = True) -> Tuple[Mapping, Mapping]:
137
- ...
131
+ def normalize(obj: Mapping, check_root: bool = True) -> Tuple[Mapping, Mapping]: ...
138
132
 
139
133
 
140
134
  @overload
141
- def normalize(obj: List, check_root: bool = True) -> Tuple[List, Mapping]:
142
- ...
135
+ def normalize(obj: List, check_root: bool = True) -> Tuple[List, Mapping]: ...
143
136
 
144
137
 
145
138
  def normalize(obj: JsonLike, check_root: bool = True) -> Tuple[JsonLike, Mapping]:
@@ -169,7 +162,7 @@ def normalize(obj: JsonLike, check_root: bool = True) -> Tuple[JsonLike, Mapping
169
162
  for key, value in _loop(obj):
170
163
  # For iterables, recursively call normalize
171
164
  if isinstance(value, (dict, list)):
172
- _normalized, _lookup = normalize(value) # type: ignore
165
+ _normalized, _lookup = normalize(value) # type: ignore
173
166
  output[key] = _normalized # type: ignore
174
167
  lookup.update(_lookup)
175
168
  else:
@@ -180,13 +173,11 @@ def normalize(obj: JsonLike, check_root: bool = True) -> Tuple[JsonLike, Mapping
180
173
 
181
174
 
182
175
  @overload
183
- def denormalize(normalized_obj: Mapping, lookup: Mapping) -> Mapping:
184
- ...
176
+ def denormalize(normalized_obj: Mapping, lookup: Mapping) -> Mapping: ...
185
177
 
186
178
 
187
179
  @overload
188
- def denormalize(normalized_obj: List, lookup: Mapping) -> List:
189
- ...
180
+ def denormalize(normalized_obj: List, lookup: Mapping) -> List: ...
190
181
 
191
182
 
192
183
  def denormalize(normalized_obj: JsonLike, lookup: Mapping) -> Optional[JsonLike]:
@@ -206,7 +197,7 @@ def denormalize(normalized_obj: JsonLike, lookup: Mapping) -> Optional[JsonLike]
206
197
  # Whole object is a placeholder
207
198
  if _is_placeholder(normalized_obj):
208
199
  ref = normalized_obj['__ref']
209
- referrable = lookup[ref] if ref in lookup else None
200
+ referrable = lookup.get(ref, None)
210
201
 
211
202
  if isinstance(referrable, (list, dict)):
212
203
  return denormalize(referrable, lookup)
@@ -15,13 +15,24 @@ See the License for the specific language governing permissions and
15
15
  limitations under the License.
16
16
  """
17
17
 
18
- from typing import Optional, TypeVar
18
+ import json
19
+ import uuid
20
+ from typing import Any, Literal, Optional, TypeVar, Union, cast, overload
19
21
 
20
- from pandas import DataFrame, MultiIndex
22
+ from pandas import DataFrame, MultiIndex, Series
23
+ from typing_extensions import TypedDict, TypeGuard
21
24
 
22
25
  # Name of the synthetic positional-index column that append_index inserts and remove_index drops
  INDEX = '__index__'
23
26
 
24
27
 
28
# Typing-only signature: a DataFrame in means a DataFrame out
@overload
def append_index(df: DataFrame) -> DataFrame: ...


# Typing-only signature: None passes straight through
@overload
def append_index(df: None) -> None: ...
34
+
35
+
25
36
  def append_index(df: Optional[DataFrame]) -> Optional[DataFrame]:
26
37
  """
27
38
  Add a numerical index column to the dataframe
@@ -31,7 +42,7 @@ def append_index(df: Optional[DataFrame]) -> Optional[DataFrame]:
31
42
 
32
43
  if INDEX not in df.columns:
33
44
  new_df = df.copy()
34
- new_df.insert(0, INDEX, range(0, len(df.index)))
45
+ new_df.insert(0, INDEX, range(0, len(df.index))) # type: ignore
35
46
  return new_df
36
47
 
37
48
  return df
@@ -47,7 +58,7 @@ def remove_index(value: value_type) -> value_type:
47
58
  Otherwise return same value untouched.
48
59
  """
49
60
  if isinstance(value, DataFrame):
50
- return value.drop(columns=['__index__'], inplace=False, errors='ignore')
61
+ return cast(value_type, value.drop(columns=['__index__'], inplace=False, errors='ignore'))
51
62
 
52
63
  return value
53
64
 
@@ -65,6 +76,12 @@ def df_convert_to_internal(original_df: DataFrame) -> DataFrame:
65
76
  if any(isinstance(c, str) and c.startswith('__col__') for c in df.columns):
66
77
  return df
67
78
 
79
+ # Apply display transformations to the DataFrame
80
+ format_for_display(df)
81
+
82
+ # Append index to match the way we process the original DataFrame
83
+ df = cast(DataFrame, append_index(df))
84
+
68
85
  # Handle hierarchical columns: [(A, B), (A, C)] -> ['A_B', 'A_C']
69
86
  if isinstance(df.columns, MultiIndex):
70
87
  df.columns = ['_'.join(col).strip() if col[0] != INDEX else INDEX for col in df.columns.values]
@@ -90,4 +107,89 @@ def df_convert_to_internal(original_df: DataFrame) -> DataFrame:
90
107
 
91
108
 
92
109
  def df_to_json(df: DataFrame) -> str:
93
- return df_convert_to_internal(df).to_json(orient='records') or ''
110
+ return df_convert_to_internal(df).to_json(orient='records', date_unit='ns') or ''
111
+
112
+
113
def format_for_display(df: DataFrame) -> None:
    """
    Apply transformations to a DataFrame to make it suitable for display.

    Every ``object``-dtype column has all of its values converted to ``str``. Mixed object
    columns (for example datetime and number objects in the same column) otherwise cause
    issues when the data is displayed in the Table component.

    Columns are addressed by position rather than by label. That way frames with duplicated
    column names are handled too: for a duplicated label ``df[label]`` returns a sub-DataFrame
    copy, so writing through it would never reach ``df``.

    Note: this mutates ``df`` in place and does NOT make a copy.

    :param df: the DataFrame to transform in place
    """
    for position, dtype in enumerate(df.dtypes):
        if dtype == 'object':
            df.iloc[:, position] = df.iloc[:, position].apply(str)
130
+
131
+
132
class FieldType(TypedDict):
    """A single field entry (one column or index level) of a DataFrameSchema."""

    name: Union[str, tuple[str, ...]]
    # NOTE(review): get_schema replaces 'datetime' with the concrete dtype string
    # (e.g. 'datetime64[ns]'), which is not one of these literals - consider widening.
    type: Literal['integer', 'number', 'boolean', 'datetime', 'duration', 'any', 'str']


class DataFrameSchema(TypedDict):
    """Table Schema description of a DataFrame, as returned by get_schema."""

    fields: list[FieldType]
    primaryKey: list[str]


class DataResponse(TypedDict):
    """Tabular payload together with a row ``count`` and an optional schema."""

    data: Optional[DataFrame]
    count: int
    schema: Optional[DataFrameSchema]
146
+
147
+
148
def is_data_response(response: Any) -> TypeGuard[DataResponse]:
    """
    Check whether ``response`` has the shape of a DataResponse.

    ``response`` must be a dict containing both ``data`` and ``count`` keys, and ``data`` must
    be None or a DataFrame. Neither the ``schema`` key nor the type of ``count`` is checked.

    :param response: any value
    :return: True if ``response`` can be treated as a DataResponse
    """
    if not isinstance(response, dict):
        return False
    if 'data' not in response or 'count' not in response:
        return False
    payload = response['data']
    return payload is None or isinstance(payload, DataFrame)
153
+
154
+
155
def data_response_to_json(response: DataResponse) -> str:
    """
    Serialize a DataResponse to a JSON string.

    ``json.dumps`` ``default=`` hooks can only map a value to another serializable value.
    pandas' ``to_json``, by contrast, returns an already-serialized string. To avoid
    serializing the DataFrame twice, it is first dumped as a unique placeholder string.
    That quoted placeholder is then swapped for the DataFrame's own JSON (or ``null``).

    :param response: the response to serialize
    :return: the JSON document
    """
    marker = str(uuid.uuid4())

    def _encode_fallback(value: Any) -> Any:
        # Only DataFrames are special-cased; anything else stays unserializable
        if not isinstance(value, DataFrame):
            raise TypeError(f'Object of type {type(value)} is not JSON serializable')
        return marker

    encoded = json.dumps(response, default=_encode_fallback)
    frame = response['data']
    frame_json = df_to_json(frame) if frame is not None else 'null'
    return encoded.replace(f'"{marker}"', frame_json)
174
+
175
+
176
def build_data_response(data: DataFrame, count: int) -> DataResponse:
    """
    Bundle ``data`` and ``count`` into a DataResponse.

    The schema is computed from the internal-format conversion of ``data``. The original,
    unconverted ``data`` is what gets stored in the response.

    :param data: the DataFrame to return
    :param count: the row count to report alongside it
    :return: the assembled DataResponse
    """
    schema = get_schema(df_convert_to_internal(data))
    return DataResponse(data=data, count=count, schema=schema)
181
+
182
+
183
def get_schema(df: DataFrame):
    """
    Build a Table Schema description of ``df`` using pandas' ``build_table_schema``.

    pandas reports every datetime field with the generic type ``'datetime'``. The resolution
    is needed as well, so those entries are replaced with the concrete dtype string
    (e.g. ``'datetime64[ns]'``).

    :param df: the DataFrame to describe
    :return: the schema dict produced by pandas, with datetime field types made concrete
    """
    from pandas.io.json._table_schema import build_table_schema

    raw_schema = build_table_schema(df)

    # build_table_schema emits the index field(s) first (one per level), then one field per
    # column, in column order. Fields are paired with dtypes positionally rather than looked up
    # via df[name]. Index fields are not columns, so df['index'] would raise KeyError or hit an
    # unrelated column. A duplicated column label makes df[name] a DataFrame with no .dtype.
    index_dtypes = [df.index.get_level_values(level).dtype for level in range(df.index.nlevels)]
    field_dtypes = [*index_dtypes, *df.dtypes]

    for field_data, dtype in zip(cast(list, raw_schema['fields']), field_dtypes):
        if field_data.get('type') == 'datetime':
            field_data['type'] = str(dtype)

    return cast('DataFrameSchema', raw_schema)
@@ -95,7 +95,7 @@ class TaskDefinition:
95
95
  def __await__(self):
96
96
  """Await the underlying event, then return or raise the result"""
97
97
  yield from self.event.wait().__await__()
98
- if isinstance(self.result, Exception):
98
+ if isinstance(self.result, BaseException):
99
99
  raise self.result
100
100
  return self.result
101
101
 
@@ -16,10 +16,11 @@ limitations under the License.
16
16
  """
17
17
 
18
18
  import atexit
19
+ from collections.abc import Coroutine
19
20
  from contextlib import contextmanager
20
21
  from datetime import datetime
21
22
  from multiprocessing import active_children
22
- from typing import Any, Callable, Coroutine, Dict, Optional, cast
23
+ from typing import Any, Callable, Dict, Optional, Union, cast
23
24
 
24
25
  import anyio
25
26
  from anyio.abc import TaskGroup
@@ -102,16 +103,18 @@ class TaskPool:
102
103
  try:
103
104
  await wait_while(
104
105
  lambda: self.status != PoolStatus.RUNNING
105
- or not len(self.workers) == self.desired_workers
106
+ or len(self.workers) != self.desired_workers
106
107
  or not all(w.status == WorkerStatus.IDLE for w in self.workers.values()),
107
108
  timeout=timeout,
108
109
  )
109
- except TimeoutError:
110
- raise RuntimeError('Failed to start pool')
110
+ except TimeoutError as e:
111
+ raise RuntimeError('Failed to start pool') from e
111
112
  else:
112
113
  raise RuntimeError('Pool already started')
113
114
 
114
- def submit(self, task_uid: str, function_name: str, args: tuple = (), kwargs: dict = {}) -> TaskDefinition:
115
+ def submit(
116
+ self, task_uid: str, function_name: str, args: Union[tuple, None] = None, kwargs: Union[dict, None] = None
117
+ ) -> TaskDefinition:
115
118
  """
116
119
  Submit a new task to the pool
117
120
 
@@ -120,6 +123,10 @@ class TaskPool:
120
123
  :param args: list of arguments to pass to the function
121
124
  :param kwargs: dict of kwargs to pass to the function
122
125
  """
126
+ if args is None:
127
+ args = ()
128
+ if kwargs is None:
129
+ kwargs = {}
123
130
  self._check_pool_state()
124
131
 
125
132
  # Create a task definition to keep track of its progress
@@ -151,7 +158,7 @@ class TaskPool:
151
158
 
152
159
  task = self.tasks.pop(task_uid)
153
160
  if not task.event.is_set():
154
- task.result = Exception('Task cancelled')
161
+ task.result = anyio.get_cancelled_exc_class()()
155
162
  task.event.set()
156
163
 
157
164
  # Task in progress, stop the worker
@@ -463,9 +470,8 @@ class TaskPool:
463
470
  )
464
471
  elif is_log(worker_msg):
465
472
  dev_logger.info(f'Task: {worker_msg["task_uid"]}', {'logs': worker_msg['log']})
466
- elif is_progress(worker_msg):
467
- if worker_msg['task_uid'] in self._progress_subscribers:
468
- await self._progress_subscribers[worker_msg['task_uid']](worker_msg['progress'], worker_msg['message'])
473
+ elif is_progress(worker_msg) and worker_msg['task_uid'] in self._progress_subscribers:
474
+ await self._progress_subscribers[worker_msg['task_uid']](worker_msg['progress'], worker_msg['message'])
469
475
 
470
476
  async def _wait_queue_depletion(self, timeout: Optional[float] = None):
471
477
  """
@@ -478,8 +484,8 @@ class TaskPool:
478
484
  condition=lambda: self.status in (PoolStatus.CLOSED, PoolStatus.RUNNING) and len(self.tasks) > 0,
479
485
  timeout=timeout,
480
486
  )
481
- except TimeoutError:
482
- raise TimeoutError('Tasks are still being executed')
487
+ except TimeoutError as e:
488
+ raise TimeoutError('Tasks are still being executed') from e
483
489
 
484
490
  async def _core_loop(self):
485
491
  """
@@ -493,11 +499,14 @@ class TaskPool:
493
499
  while self.status not in (PoolStatus.ERROR, PoolStatus.STOPPED):
494
500
  await anyio.sleep(0.1)
495
501
 
496
- self._handle_excess_workers()
497
- self._handle_orphaned_workers()
498
- self._handle_dead_workers()
499
- self._create_workers()
500
- await self._process_next_worker_message()
502
+ try:
503
+ self._handle_excess_workers()
504
+ self._handle_orphaned_workers()
505
+ self._handle_dead_workers()
506
+ self._create_workers()
507
+ await self._process_next_worker_message()
508
+ except Exception as e:
509
+ dev_logger.error('Error in task pool', e)
501
510
  finally:
502
511
  self.loop_stopped.set()
503
512
 
@@ -27,6 +27,8 @@ from typing import Any, Callable, Optional, Tuple
27
27
  import anyio
28
28
  from tblib import Traceback
29
29
 
30
+ from dara.core.logging import dev_logger
31
+
30
32
 
31
33
  class SubprocessException:
32
34
  """
@@ -89,9 +91,7 @@ def read_from_shared_memory(pointer: SharedMemoryPointer) -> Any:
89
91
  data = shared_mem.buf[:data_size]
90
92
 
91
93
  # Unpickle and deepcopy
92
- decoded_payload_shared = pickle.loads(
93
- shared_mem.buf
94
- ) # nosec B301 # we trust the shared memory pointer passed by the pool
94
+ decoded_payload_shared = pickle.loads(shared_mem.buf) # nosec B301 # we trust the shared memory pointer passed by the pool
95
95
  decoded_payload = copy.deepcopy(decoded_payload_shared)
96
96
 
97
97
  # Cleanup
@@ -133,20 +133,23 @@ async def stop_process_async(process: BaseProcess, timeout: float = 3):
133
133
  # Terminate and wait for it to shutdown
134
134
  process.terminate()
135
135
 
136
- # mimic process.join() in an async way to not block
137
- await wait_while(process.is_alive, timeout)
138
-
139
- # If it's still alive
140
- if process.is_alive():
141
- try:
142
- os.kill(process.pid, signal.SIGKILL)
143
- await wait_while(process.is_alive, timeout)
144
- except OSError:
136
+ try:
137
+ # mimic process.join() in an async way to not block
138
+ await wait_while(process.is_alive, timeout)
139
+
140
+ # If it's still alive
141
+ if process.is_alive():
142
+ try:
143
+ os.kill(process.pid, signal.SIGKILL)
144
+ await wait_while(process.is_alive, timeout)
145
+ except OSError as e:
146
+ raise RuntimeError(f'Unable to terminate subprocess with PID {process.pid}') from e
147
+
148
+ # If it's still alive raise an exception
149
+ if process.is_alive():
145
150
  raise RuntimeError(f'Unable to terminate subprocess with PID {process.pid}')
146
-
147
- # If it's still alive raise an exception
148
- if process.is_alive():
149
- raise RuntimeError(f'Unable to terminate subprocess with PID {process.pid}')
151
+ except Exception as e:
152
+ dev_logger.error('Error stopping process', e)
150
153
 
151
154
 
152
155
  def stop_process(process: BaseProcess, timeout: float = 3):
@@ -171,8 +174,8 @@ def stop_process(process: BaseProcess, timeout: float = 3):
171
174
  try:
172
175
  os.kill(process.pid, signal.SIGKILL)
173
176
  process.join(timeout)
174
- except OSError:
175
- raise RuntimeError(f'Unable to terminate subprocess with PID {process.pid}')
177
+ except OSError as e:
178
+ raise RuntimeError(f'Unable to terminate subprocess with PID {process.pid}') from e
176
179
 
177
180
  # If it's still alive raise an exception
178
181
  if process.is_alive():
@@ -57,7 +57,8 @@ class StdoutLogger:
57
57
  self.channel.worker_api.log(self.task_uid, msg)
58
58
 
59
59
  def flush(self):
60
- sys.__stdout__.flush()
60
+ if sys.__stdout__:
61
+ sys.__stdout__.flush()
61
62
 
62
63
 
63
64
  def execute_function(func: Callable, args: tuple, kwargs: dict):
@@ -164,7 +165,7 @@ def worker_loop(worker_params: WorkerParameters, channel: Channel):
164
165
 
165
166
  # Redirect logs via the channel
166
167
  stdout_logger = StdoutLogger(task_uid, channel)
167
- sys.stdout = stdout_logger # type: ignore
168
+ sys.stdout = stdout_logger # type: ignore
168
169
 
169
170
  try:
170
171
  payload_pointer = task['payload']
@@ -27,7 +27,7 @@ def is_available(host: str, port: int) -> bool:
27
27
  """
28
28
  try:
29
29
  with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
30
- sock.settimeout(2.0) # timeout in case port is blocked
30
+ sock.settimeout(2.0) # timeout in case port is blocked
31
31
  return sock.connect_ex((host, port)) != 0
32
32
  except BaseException:
33
33
  return False