dara-core 1.21.16-py3-none-any.whl → 1.21.17-py3-none-any.whl

This diff shows the contents of publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only and reflects the changes between the two released versions.
Files changed (87)
  1. dara/core/auth/base.py +5 -5
  2. dara/core/auth/basic.py +3 -3
  3. dara/core/auth/definitions.py +13 -14
  4. dara/core/auth/routes.py +7 -5
  5. dara/core/auth/utils.py +11 -10
  6. dara/core/base_definitions.py +30 -36
  7. dara/core/cli.py +7 -8
  8. dara/core/configuration.py +51 -58
  9. dara/core/css.py +2 -2
  10. dara/core/data_utils.py +12 -17
  11. dara/core/defaults.py +3 -3
  12. dara/core/definitions.py +58 -63
  13. dara/core/http.py +4 -4
  14. dara/core/interactivity/actions.py +34 -42
  15. dara/core/interactivity/any_data_variable.py +1 -1
  16. dara/core/interactivity/any_variable.py +6 -5
  17. dara/core/interactivity/client_variable.py +1 -2
  18. dara/core/interactivity/condition.py +2 -2
  19. dara/core/interactivity/data_variable.py +2 -4
  20. dara/core/interactivity/derived_data_variable.py +7 -10
  21. dara/core/interactivity/derived_variable.py +45 -51
  22. dara/core/interactivity/filtering.py +19 -19
  23. dara/core/interactivity/loop_variable.py +2 -4
  24. dara/core/interactivity/non_data_variable.py +1 -1
  25. dara/core/interactivity/plain_variable.py +21 -18
  26. dara/core/interactivity/server_variable.py +13 -15
  27. dara/core/interactivity/state_variable.py +4 -5
  28. dara/core/interactivity/switch_variable.py +16 -16
  29. dara/core/interactivity/tabular_variable.py +3 -3
  30. dara/core/interactivity/url_variable.py +3 -3
  31. dara/core/internal/cache_store/cache_store.py +6 -6
  32. dara/core/internal/cache_store/keep_all.py +3 -3
  33. dara/core/internal/cache_store/lru.py +8 -8
  34. dara/core/internal/cache_store/ttl.py +4 -4
  35. dara/core/internal/custom_response.py +3 -3
  36. dara/core/internal/dependency_resolution.py +6 -10
  37. dara/core/internal/devtools.py +2 -3
  38. dara/core/internal/download.py +5 -6
  39. dara/core/internal/encoder_registry.py +7 -11
  40. dara/core/internal/execute_action.py +5 -5
  41. dara/core/internal/hashing.py +1 -2
  42. dara/core/internal/import_discovery.py +7 -9
  43. dara/core/internal/normalization.py +12 -15
  44. dara/core/internal/pandas_utils.py +6 -6
  45. dara/core/internal/pool/channel.py +3 -4
  46. dara/core/internal/pool/definitions.py +9 -9
  47. dara/core/internal/pool/task_pool.py +8 -8
  48. dara/core/internal/pool/utils.py +4 -3
  49. dara/core/internal/pool/worker.py +3 -3
  50. dara/core/internal/registries.py +4 -4
  51. dara/core/internal/registry.py +3 -3
  52. dara/core/internal/registry_lookup.py +4 -4
  53. dara/core/internal/routing.py +23 -22
  54. dara/core/internal/scheduler.py +8 -8
  55. dara/core/internal/settings.py +1 -2
  56. dara/core/internal/store.py +9 -9
  57. dara/core/internal/tasks.py +30 -30
  58. dara/core/internal/utils.py +9 -15
  59. dara/core/internal/websocket.py +18 -18
  60. dara/core/js_tooling/js_utils.py +19 -19
  61. dara/core/logging.py +13 -13
  62. dara/core/main.py +4 -5
  63. dara/core/metrics/cache.py +2 -4
  64. dara/core/persistence.py +19 -25
  65. dara/core/router/compat.py +1 -3
  66. dara/core/router/components.py +10 -10
  67. dara/core/router/dependency_graph.py +2 -4
  68. dara/core/router/router.py +43 -42
  69. dara/core/visual/components/dynamic_component.py +1 -3
  70. dara/core/visual/components/fallback.py +3 -3
  71. dara/core/visual/components/for_cmp.py +5 -5
  72. dara/core/visual/components/menu.py +1 -3
  73. dara/core/visual/components/router_content.py +1 -3
  74. dara/core/visual/components/sidebar_frame.py +8 -10
  75. dara/core/visual/components/theme_provider.py +3 -3
  76. dara/core/visual/components/topbar_frame.py +8 -10
  77. dara/core/visual/css/__init__.py +277 -277
  78. dara/core/visual/dynamic_component.py +18 -22
  79. dara/core/visual/progress_updater.py +1 -1
  80. dara/core/visual/template.py +10 -12
  81. dara/core/visual/themes/definitions.py +46 -46
  82. {dara_core-1.21.16.dist-info → dara_core-1.21.17.dist-info}/METADATA +12 -13
  83. dara_core-1.21.17.dist-info/RECORD +127 -0
  84. dara_core-1.21.16.dist-info/RECORD +0 -127
  85. {dara_core-1.21.16.dist-info → dara_core-1.21.17.dist-info}/LICENSE +0 -0
  86. {dara_core-1.21.16.dist-info → dara_core-1.21.17.dist-info}/WHEEL +0 -0
  87. {dara_core-1.21.16.dist-info → dara_core-1.21.17.dist-info}/entry_points.txt +0 -0

dara/core/internal/dependency_resolution.py +6 -10

@@ -15,9 +15,9 @@ See the License for the specific language governing permissions and
  limitations under the License.
  """

- from typing import Any, List, Literal, Optional, Union
+ from typing import Any, Literal, TypeGuard

- from typing_extensions import TypedDict, TypeGuard
+ from typing_extensions import TypedDict

  from dara.core.base_definitions import BaseTask, PendingTask
  from dara.core.interactivity import DerivedVariable

@@ -31,8 +31,8 @@ from dara.core.logging import dev_logger
  class ResolvedDerivedVariable(TypedDict):
      type: Literal['derived']
      uid: str
-     values: List[Any]
-     force_key: Optional[str]
+     values: list[Any]
+     force_key: str | None


  class ResolvedServerVariable(TypedDict):

@@ -90,11 +90,7 @@ def clean_force_key(value: Any) -> Any:


  async def resolve_dependency(
-     entry: Union[
-         ResolvedDerivedVariable,
-         ResolvedSwitchVariable,
-         Any,
-     ],
+     entry: ResolvedDerivedVariable | ResolvedSwitchVariable | Any,
      store: CacheStore,
      task_mgr: TaskManager,
  ):

@@ -135,7 +131,7 @@ async def _resolve_derived_var(

      registry_mgr: RegistryLookup = utils_registry.get('RegistryLookup')
      var = await registry_mgr.get(derived_variable_registry, str(derived_variable_entry.get('uid')))
-     input_values: List[Any] = derived_variable_entry.get('values', [])
+     input_values: list[Any] = derived_variable_entry.get('values', [])
      result = await DerivedVariable.get_value(
          var_entry=var,
          store=store,
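
The pattern above repeats across most files in this release: `typing.List`/`Optional`/`Union` give way to built-in generics and PEP 604 unions, and `TypeGuard` is imported from `typing` rather than `typing_extensions`. A minimal sketch of the new style, assuming a Python 3.10+ baseline; the `is_resolved_derived_variable` helper below is illustrative and not part of dara.core:

```python
# Modernized annotation style used throughout this diff (requires Python 3.10+).
from typing import Any, Literal, TypeGuard, TypedDict


class ResolvedDerivedVariable(TypedDict):
    type: Literal['derived']
    uid: str
    values: list[Any]       # previously List[Any]
    force_key: str | None   # previously Optional[str]


def is_resolved_derived_variable(entry: Any) -> TypeGuard[ResolvedDerivedVariable]:
    """Hypothetical helper: narrow an arbitrary payload to the TypedDict above."""
    return isinstance(entry, dict) and entry.get('type') == 'derived'


payload: Any = {'type': 'derived', 'uid': 'abc', 'values': [1, 2], 'force_key': None}
if is_resolved_derived_variable(payload):
    # Static type checkers now treat payload as a ResolvedDerivedVariable here
    print(payload['uid'])
```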

dara/core/internal/devtools.py +2 -3

@@ -21,13 +21,12 @@ import sys
  import traceback
  from contextlib import contextmanager
  from datetime import datetime
- from typing import Optional

  from dara.core.internal.websocket import WebsocketManager
  from dara.core.logging import eng_logger


- def print_stacktrace(err: Optional[BaseException] = None) -> str:
+ def print_stacktrace(err: BaseException | None = None) -> str:
      """
      Prints out the current stack trace. Will also extract any exceptions and print them at the end.
      """

@@ -59,7 +58,7 @@ def handle_system_exit(error_msg: str):
          raise InterruptedError(error_msg) from e


- def get_error_for_channel(err: Optional[BaseException] = None) -> dict:
+ def get_error_for_channel(err: BaseException | None = None) -> dict:
      """
      Get error from current stacktrace to send to the client
      """

dara/core/internal/download.py +5 -6

@@ -18,9 +18,8 @@ limitations under the License.
  from __future__ import annotations

  import os
- from collections.abc import Awaitable
+ from collections.abc import Awaitable, Callable
  from contextvars import ContextVar
- from typing import Callable, Optional, Tuple
  from uuid import uuid4

  import anyio

@@ -38,8 +37,8 @@ class DownloadDataEntry(BaseModel):
      uid: str
      file_path: str
      cleanup_file: bool
-     identity_name: Optional[str] = None
-     download: Callable[[DownloadDataEntry], Awaitable[Tuple[anyio.AsyncFile, Callable[..., Awaitable]]]]
+     identity_name: str | None = None
+     download: Callable[[DownloadDataEntry], Awaitable[tuple[anyio.AsyncFile, Callable[..., Awaitable]]]]
      """Handler for getting the file from the entry"""


@@ -48,7 +47,7 @@ DownloadRegistryEntry = CachedRegistryEntry(
  ) # expire the codes after 10 minutes


- async def download(data_entry: DownloadDataEntry) -> Tuple[anyio.AsyncFile, Callable[..., Awaitable]]:
+ async def download(data_entry: DownloadDataEntry) -> tuple[anyio.AsyncFile, Callable[..., Awaitable]]:
      """
      Get the loaded filename and path from a code

@@ -74,7 +73,7 @@ async def download(data_entry: DownloadDataEntry) -> Tuple[anyio.AsyncFile, Call
      return (async_file, cleanup)


- GENERATE_CODE_OVERRIDE = ContextVar[Optional[Callable[[str], str]]]('GENERATE_CODE_OVERRIDE', default=None)
+ GENERATE_CODE_OVERRIDE = ContextVar[Callable[[str], str] | None]('GENERATE_CODE_OVERRIDE', default=None)
  """
  Optional context variable which can be used to override the default behaviour of code generation.
  Invoked with the file path to generate a download code for.
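
The `GENERATE_CODE_OVERRIDE` context variable above holds an optional `Callable[[str], str]`. A minimal sketch of how such an override could be installed and restored with the standard `contextvars` API; `custom_code` and the printed path are illustrative only, and the actual call sites inside dara.core are not shown in this diff:

```python
from collections.abc import Callable
from contextvars import ContextVar

# Same shape as the declaration in the hunk above (Python 3.10+ union syntax)
GENERATE_CODE_OVERRIDE: ContextVar[Callable[[str], str] | None] = ContextVar(
    'GENERATE_CODE_OVERRIDE', default=None
)


def custom_code(file_path: str) -> str:
    # Illustrative generator: derive a download code from the file path
    return f'code-for-{file_path}'


token = GENERATE_CODE_OVERRIDE.set(custom_code)
try:
    override = GENERATE_CODE_OVERRIDE.get()
    code = override('/tmp/report.csv') if override is not None else 'default-code'
    print(code)
finally:
    GENERATE_CODE_OVERRIDE.reset(token)  # restore the default behaviour
```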

dara/core/internal/encoder_registry.py +7 -11

@@ -15,14 +15,10 @@ See the License for the specific language governing permissions and
  limitations under the License.
  """

- from collections.abc import MutableMapping
+ from collections.abc import Callable, MutableMapping
  from inspect import Parameter, isclass
  from typing import (
      Any,
-     Callable,
-     Dict,
-     Optional,
-     Type,
      Union,
      get_args,
      get_origin,

@@ -47,7 +43,7 @@ def _not_implemented(x, dtype):
      raise NotImplementedError(f'No deserialization implementation for item {x} of dtype {dtype}')


- def _get_numpy_dtypes_encoder(typ: Type[Any]):
+ def _get_numpy_dtypes_encoder(typ: type[Any]):
      """
      Construct numpy generic datatype

@@ -56,7 +52,7 @@ def _get_numpy_dtypes_encoder(typ: Type[Any]):
      return Encoder(serialize=lambda x: x.item(), deserialize=lambda x: typ(x))


- def _get_numpy_str_encoder(typ: Type[Any]):
+ def _get_numpy_str_encoder(typ: type[Any]):
      """
      Construct numpy str datatype

@@ -65,7 +61,7 @@ def _get_numpy_str_encoder(typ: Type[Any]):
      return Encoder(serialize=lambda x: str(x), deserialize=lambda x: typ(x))


- def _get_pandas_array_encoder(array_type: Type[Any], dtype: Any, raise_: bool = False):
+ def _get_pandas_array_encoder(array_type: type[Any], dtype: Any, raise_: bool = False):
      return Encoder(
          serialize=lambda x: x.astype(str).tolist(),
          deserialize=lambda x: pandas.array(x, dtype=dtype) if not raise_ else _not_implemented(x, dtype),

@@ -127,7 +123,7 @@ def _df_deserialize(x):


  # A encoder_registry to handle serialization/deserialization for numpy/pandas type
- encoder_registry: MutableMapping[Type[Any], Encoder] = {
+ encoder_registry: MutableMapping[type[Any], Encoder] = {
      int: Encoder(serialize=lambda x: x, deserialize=lambda x: int(x)),
      float: Encoder(serialize=lambda x: x, deserialize=lambda x: float(x)),
      str: Encoder(serialize=lambda x: x, deserialize=lambda x: str(x)),

@@ -211,14 +207,14 @@ else:
      )


- def get_jsonable_encoder() -> Dict[Type[Any], Callable[..., Any]]:
+ def get_jsonable_encoder() -> dict[type[Any], Callable[..., Any]]:
      """
      Get the encoder registry as a dict of `{type: serialize}` pairs
      """
      return {k: v['serialize'] for k, v in encoder_registry.items()}


- def deserialize(value: Any, typ: Optional[Type]):
+ def deserialize(value: Any, typ: type | None):
      """
      Deserialize a value into a given type.
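
The registry above maps concrete Python types to serialize/deserialize pairs. A small self-contained sketch of that pattern, under the assumption that `Encoder` is dict-like (the hunks read it as `v['serialize']`); the real dara.core `Encoder` and its numpy/pandas entries are not reproduced here:

```python
from collections.abc import Callable, MutableMapping
from typing import Any, TypedDict


class Encoder(TypedDict):
    serialize: Callable[[Any], Any]
    deserialize: Callable[[Any], Any]


# Mapping from a Python type to its encoder, mirroring the shape in the hunk above
encoder_registry: MutableMapping[type[Any], Encoder] = {
    int: Encoder(serialize=lambda x: x, deserialize=lambda x: int(x)),
    str: Encoder(serialize=lambda x: x, deserialize=lambda x: str(x)),
}


def get_jsonable_encoder() -> dict[type[Any], Callable[..., Any]]:
    # Same {type: serialize} shape as the helper shown above
    return {k: v['serialize'] for k, v in encoder_registry.items()}


def deserialize(value: Any, typ: type | None) -> Any:
    entry = encoder_registry.get(typ) if typ is not None else None
    return entry['deserialize'](value) if entry else value


print(deserialize('3', int))               # 3
print(get_jsonable_encoder()[str]('abc'))  # 'abc'
```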
 

dara/core/internal/execute_action.py +5 -5

@@ -18,10 +18,10 @@ limitations under the License.
  from __future__ import annotations

  import asyncio
- from collections.abc import Mapping
+ from collections.abc import Callable, Mapping
  from contextvars import ContextVar
  from functools import partial
- from typing import Any, Callable, Literal, Optional, Union
+ from typing import Any, Literal

  import anyio

@@ -132,7 +132,7 @@ async def execute_action_sync(
      results = []

      # Construct a context which handles action messages by accumulating them in an array
-     async def handle_action(act_impl: Optional[ActionImpl]):
+     async def handle_action(act_impl: ActionImpl | None):
          if act_impl is not None:
              results.append(act_impl)

@@ -176,7 +176,7 @@ async def execute_action(
      ws_channel: str,
      store: CacheStore,
      task_mgr: TaskManager,
- ) -> Union[Any, BaseTask]:
+ ) -> Any | BaseTask:
      """
      Execute a given action with the provided context.

@@ -201,7 +201,7 @@ async def execute_action(
      assert action is not None, 'Action resolver must be defined'

      # Construct a context which handles action messages by sending them to the frontend
-     async def handle_action(act_impl: Optional[ActionImpl]):
+     async def handle_action(act_impl: ActionImpl | None):
          await ws_mgr.send_message(ws_channel, {'action': act_impl, 'uid': execution_id})

      ctx = ActionCtx(inp, handle_action)

dara/core/internal/hashing.py +1 -2

@@ -17,12 +17,11 @@ limitations under the License.

  import hashlib
  import json
- from typing import Union

  from pydantic import BaseModel


- def hash_object(obj: Union[BaseModel, dict, None]):
+ def hash_object(obj: BaseModel | dict | None):
      """
      Create a unique hash for the object.
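
For reference, a minimal sketch of what a `hash_object(obj: BaseModel | dict | None)` helper can look like with `hashlib` and `json`; this is an illustration only (and assumes the pydantic v2 `model_dump_json` API), not the implementation shipped in dara.core:

```python
import hashlib
import json

from pydantic import BaseModel


def hash_object(obj: BaseModel | dict | None) -> str:
    # Serialize deterministically, then hash the bytes
    if isinstance(obj, BaseModel):
        payload = obj.model_dump_json()  # pydantic v2; v1 would use obj.json()
    else:
        payload = json.dumps(obj, sort_keys=True, default=str)
    return hashlib.sha256(payload.encode('utf-8')).hexdigest()


class Point(BaseModel):
    x: int
    y: int


print(hash_object(Point(x=1, y=2)))
print(hash_object({'a': 1}))
print(hash_object(None))
```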
 

dara/core/internal/import_discovery.py +7 -9

@@ -18,9 +18,7 @@ limitations under the License.
  import inspect
  import sys
  from types import ModuleType
- from typing import Any, List, Optional, Set, Tuple, Type, Union
-
- from typing_extensions import TypeGuard
+ from typing import Any, TypeGuard

  from dara.core.base_definitions import ActionDef, ActionImpl
  from dara.core.definitions import ComponentInstance, JsComponentDef, discover

@@ -44,7 +42,7 @@ def _is_action_subclass(obj: Any) -> bool:
      return inspect.isclass(obj) and issubclass(obj, ActionImpl) and obj != ActionImpl


- def is_ignored(symbol: Any, ignore_symbols: List[Any]) -> bool:
+ def is_ignored(symbol: Any, ignore_symbols: list[Any]) -> bool:
      """
      Check whether a symbol should be ignored

@@ -64,8 +62,8 @@ def is_ignored(symbol: Any, ignore_symbols: List[Any]) -> bool:


  def run_discovery(
-     module: Union[ModuleType, dict], ignore_symbols: Optional[List[Any]] = None, **kwargs
- ) -> Tuple[Set[Type[ComponentInstance]], Set[Type[ActionImpl]]]:
+     module: ModuleType | dict, ignore_symbols: list[Any] | None = None, **kwargs
+ ) -> tuple[set[type[ComponentInstance]], set[type[ActionImpl]]]:
      """
      Recursively discover components available in the global namespace within the module
      and its child modules.

@@ -167,7 +165,7 @@ def run_discovery(
      return components, actions


- def _get_symbol_module(symbol: Union[Type[ComponentInstance], Type[ActionImpl]]) -> str:
+ def _get_symbol_module(symbol: type[ComponentInstance] | type[ActionImpl]) -> str:
      """Get the root module of the component or action"""
      comp_module = symbol.__module__

@@ -182,7 +180,7 @@ def _get_symbol_module(symbol: Union[Type[ComponentInstance], Type[ActionImpl]])
      return comp_module


- def create_component_definition(component: Type[ComponentInstance], local: bool = False):
+ def create_component_definition(component: type[ComponentInstance], local: bool = False):
      """
      Create a JsComponentDef for a given component class.

@@ -203,7 +201,7 @@ def create_component_definition(component: Type[ComponentInstance], local: bool
      )


- def create_action_definition(action: Type[ActionImpl], local: bool = False):
+ def create_action_definition(action: type[ActionImpl], local: bool = False):
      """
      Create a ActionDef for a given action class.
 

dara/core/internal/normalization.py +12 -15

@@ -19,20 +19,17 @@ from collections.abc import Mapping
  from typing import (
      Any,
      Generic,
-     List,
-     Optional,
-     Tuple,
+     TypeGuard,
      TypeVar,
-     Union,
      cast,
      overload,
  )

- from typing_extensions import TypedDict, TypeGuard
+ from typing_extensions import TypedDict

  from dara.core.base_definitions import DaraBaseModel as BaseModel

- JsonLike = Union[Mapping, List]
+ JsonLike = Mapping | list

  DataType = TypeVar('DataType')

@@ -60,7 +57,7 @@ class Referrable(TypedDict):


  class ReferrableWithNested(Referrable):
-     nested: List[str]
+     nested: list[str]


  class ReferrableWithFilters(Referrable):

@@ -77,7 +74,7 @@ def _get_identifier(obj: Referrable) -> str:

      # If it's a Variable with 'nested', the property should be included in the identifier
      if _is_referrable_nested(obj) and len(obj['nested']) > 0:
-         nested = ','.join(cast(List[str], obj['nested']))
+         nested = ','.join(cast(list[str], obj['nested']))
          identifier = f'{identifier}:{nested}'

      return identifier

@@ -128,14 +125,14 @@ def _loop(iterable: JsonLike):


  @overload
- def normalize(obj: Mapping, check_root: bool = True) -> Tuple[Mapping, Mapping]: ...
+ def normalize(obj: Mapping, check_root: bool = True) -> tuple[Mapping, Mapping]: ...


  @overload
- def normalize(obj: List, check_root: bool = True) -> Tuple[List, Mapping]: ...
+ def normalize(obj: list, check_root: bool = True) -> tuple[list, Mapping]: ...


- def normalize(obj: JsonLike, check_root: bool = True) -> Tuple[JsonLike, Mapping]:
+ def normalize(obj: JsonLike, check_root: bool = True) -> tuple[JsonLike, Mapping]:
      """
      Normalize a dictionary - extract referrable data into a separate lookup dictionary, replacing instances
      found with placeholders.

@@ -148,7 +145,7 @@ def normalize(obj: JsonLike, check_root: bool = True) -> Tuple[JsonLike, Mapping
      if not isinstance(obj, (dict, list)):
          return obj, lookup

-     output: Union[Mapping[Any, Any], List[Any]] = {} if isinstance(obj, dict) else [None for x in range(len(obj))]
+     output: Mapping[Any, Any] | list[Any] = {} if isinstance(obj, dict) else [None for x in range(len(obj))]

      # The whole object is referrable
      if check_root and _is_referrable(obj):

@@ -177,10 +174,10 @@ def denormalize(normalized_obj: Mapping, lookup: Mapping) -> Mapping: ...


  @overload
- def denormalize(normalized_obj: List, lookup: Mapping) -> List: ...
+ def denormalize(normalized_obj: list, lookup: Mapping) -> list: ...


- def denormalize(normalized_obj: JsonLike, lookup: Mapping) -> Optional[JsonLike]:
+ def denormalize(normalized_obj: JsonLike, lookup: Mapping) -> JsonLike | None:
      """
      Denormalize data by replacing Placeholders found with objects from the lookup

@@ -190,7 +187,7 @@ def denormalize(normalized_obj: JsonLike, lookup: Mapping) -> Optional[JsonLike]
      if normalized_obj is None:
          return None

-     output: Union[Mapping[Any, Any], List[Any]] = (
+     output: Mapping[Any, Any] | list[Any] = (
          {} if isinstance(normalized_obj, dict) else [None for x in range(len(normalized_obj))]
      )
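
The docstrings above describe the contract: `normalize` pulls "referrable" objects out into a lookup and leaves placeholders behind, and `denormalize` reverses it. A rough sketch of that idea, with an assumed `'__ref__'` marker and placeholder format standing in for dara.core's real `Referrable`/placeholder shapes:

```python
from collections.abc import Mapping
from typing import Any

JsonLike = Mapping | list


def normalize(obj: JsonLike) -> tuple[JsonLike, dict[str, Any]]:
    lookup: dict[str, Any] = {}

    def walk(value: Any) -> Any:
        if isinstance(value, dict):
            if '__ref__' in value:                      # stand-in for "referrable"
                lookup[value['__ref__']] = value
                return f'$placeholder:{value["__ref__"]}'
            return {k: walk(v) for k, v in value.items()}
        if isinstance(value, list):
            return [walk(v) for v in value]
        return value

    return walk(obj), lookup


def denormalize(obj: JsonLike, lookup: Mapping) -> JsonLike:
    def walk(value: Any) -> Any:
        if isinstance(value, str) and value.startswith('$placeholder:'):
            return lookup[value.removeprefix('$placeholder:')]
        if isinstance(value, dict):
            return {k: walk(v) for k, v in value.items()}
        if isinstance(value, list):
            return [walk(v) for v in value]
        return value

    return walk(obj)


doc = {'title': 'page', 'var': {'__ref__': 'var-1', 'value': 42}}
normalized, lookup = normalize(doc)
print(normalized)                              # {'title': 'page', 'var': '$placeholder:var-1'}
print(denormalize(normalized, lookup) == doc)  # True
```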
 

dara/core/internal/pandas_utils.py +6 -6

@@ -17,10 +17,10 @@ limitations under the License.

  import json
  import uuid
- from typing import Any, Literal, Optional, TypeVar, Union, cast, overload
+ from typing import Any, Literal, TypeGuard, TypeVar, cast, overload

  from pandas import DataFrame, MultiIndex, Series
- from typing_extensions import TypedDict, TypeGuard
+ from typing_extensions import TypedDict

  INDEX = '__index__'

@@ -33,7 +33,7 @@ def append_index(df: DataFrame) -> DataFrame: ...
  def append_index(df: None) -> None: ...


- def append_index(df: Optional[DataFrame]) -> Optional[DataFrame]:
+ def append_index(df: DataFrame | None) -> DataFrame | None:
      """
      Add a numerical index column to the dataframe
      """

@@ -130,7 +130,7 @@ def format_for_display(df: DataFrame) -> None:


  class FieldType(TypedDict):
-     name: Union[str, tuple[str, ...]]
+     name: str | tuple[str, ...]
      type: Literal['integer', 'number', 'boolean', 'datetime', 'duration', 'any', 'str']


@@ -140,9 +140,9 @@ class DataFrameSchema(TypedDict):


  class DataResponse(TypedDict):
-     data: Optional[DataFrame]
+     data: DataFrame | None
      count: int
-     schema: Optional[DataFrameSchema]
+     schema: DataFrameSchema | None


  def is_data_response(response: Any) -> TypeGuard[DataResponse]:
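
`is_data_response` above returns `TypeGuard[DataResponse]`, so a successful check narrows an `Any` payload for static type checkers. A hedged sketch of that mechanism; the key checks here are illustrative and the real implementation is not shown in this diff:

```python
from typing import Any, TypeGuard, TypedDict

from pandas import DataFrame


class DataResponse(TypedDict):
    data: DataFrame | None
    count: int
    schema: Any  # DataFrameSchema in the real module


def is_data_response(response: Any) -> TypeGuard[DataResponse]:
    # Illustrative structural check; narrows `response` to DataResponse when True
    return (
        isinstance(response, dict)
        and 'data' in response
        and 'count' in response
        and 'schema' in response
    )


payload: Any = {'data': None, 'count': 0, 'schema': None}
if is_data_response(payload):
    # Static type checkers now treat payload as DataResponse here
    print(payload['count'])
```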

dara/core/internal/pool/channel.py +3 -4

@@ -20,7 +20,6 @@ from __future__ import annotations
  import os
  from multiprocessing import Queue, get_context
  from queue import Empty
- from typing import Optional

  from dara.core.internal.pool.definitions import (
      Acknowledgement,

@@ -56,7 +55,7 @@ class _PoolAPI:

      def get_worker_message(
          self,
-     ) -> Optional[WorkerMessage]:
+     ) -> WorkerMessage | None:
          """
          Retrieve a worker message if there is one available

@@ -100,7 +99,7 @@ class _WorkerAPI:
          """
          self._out_queue.put(Result(task_uid=task_uid, result=result))

-     def send_error(self, task_uid: Optional[str], error: BaseException):
+     def send_error(self, task_uid: str | None, error: BaseException):
          """
          Send an error back to the pool

@@ -131,7 +130,7 @@ class _WorkerAPI:
          """
          self._out_queue.put(Progress(task_uid=task_uid, progress=progress, message=message))

-     def get_task(self) -> Optional[WorkerTask]:
+     def get_task(self) -> WorkerTask | None:
          """
          Retrieve a task definition from the worker queue if there is one available
 

dara/core/internal/pool/definitions.py +9 -9

@@ -17,10 +17,10 @@ limitations under the License.

  from datetime import datetime
  from enum import Enum
- from typing import Any, Optional, Union
+ from typing import Any, TypeGuard

  from anyio import Event
- from typing_extensions import TypedDict, TypeGuard
+ from typing_extensions import TypedDict

  from dara.core.internal.pool.utils import SharedMemoryPointer, SubprocessException

@@ -75,16 +75,16 @@ class TaskDefinition:
      event: Event
      result: Any
      payload: TaskPayload
-     worker_id: Optional[int] = None
-     started_at: Optional[datetime] = None
+     worker_id: int | None = None
+     started_at: datetime | None = None
      """TODO: can be used for task timeout or metrics/visibility"""

      def __init__(
          self,
          uid: str,
          payload: TaskPayload,
-         worker_id: Optional[int] = None,
-         started_at: Optional[datetime] = None,
+         worker_id: int | None = None,
+         started_at: datetime | None = None,
      ):
          self.uid = uid
          self.payload = payload

@@ -135,14 +135,14 @@ class Result(TypedDict):
  class Problem(TypedDict):
      """Sent when a worker encounters an issue processing a task"""

-     task_uid: Optional[str]
+     task_uid: str | None
      error: SubprocessException


  class Log(TypedDict):
      """Sent when a task emits a stdout message"""

-     task_uid: Optional[str]
+     task_uid: str | None
      log: str


@@ -154,7 +154,7 @@ class Progress(TypedDict):
      message: str


- WorkerMessage = Union[Acknowledgement, Result, Problem, Initialization, Log, Progress]
+ WorkerMessage = Acknowledgement | Result | Problem | Initialization | Log | Progress
  """Union of possible messages sent from worker processes"""
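
The worker-message types above are TypedDicts combined into a PEP 604 union. A cut-down sketch of that style together with a TypeGuard-based check (an `is_progress` helper is referenced by the TaskPool hunks further down; this version and its field checks are illustrative only):

```python
from typing import Any, TypeGuard, TypedDict


class Result(TypedDict):
    task_uid: str
    result: Any


class Progress(TypedDict):
    task_uid: str
    progress: float
    message: str


# PEP 604 union of message payloads, mirroring WorkerMessage above
WorkerMessage = Result | Progress


def is_progress(msg: WorkerMessage) -> TypeGuard[Progress]:
    # Illustrative structural check on the dict payload
    return 'progress' in msg and 'message' in msg


msg: WorkerMessage = Progress(task_uid='abc', progress=0.5, message='halfway')
if is_progress(msg):
    print(f'{msg["progress"]:.0%} - {msg["message"]}')
```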
 

dara/core/internal/pool/task_pool.py +8 -8

@@ -16,11 +16,11 @@ limitations under the License.
  """

  import atexit
- from collections.abc import Coroutine
+ from collections.abc import Callable, Coroutine
  from contextlib import contextmanager
  from datetime import datetime
  from multiprocessing import active_children
- from typing import Any, Callable, Dict, Optional, Union, cast
+ from typing import Any, cast

  import anyio
  from anyio.abc import TaskGroup

@@ -61,8 +61,8 @@ class TaskPool:
      """Number of seconds worker is allowed to be idle before it is killed, if there are too many workers alive"""

      worker_parameters: WorkerParameters
-     workers: Dict[int, WorkerProcess] = {}
-     tasks: Dict[str, TaskDefinition] = {}
+     workers: dict[int, WorkerProcess] = {}
+     tasks: dict[str, TaskDefinition] = {}

      def __init__(
          self, task_group: TaskGroup, worker_parameters: WorkerParameters, max_workers: int, worker_timeout: float = 5

@@ -75,7 +75,7 @@ class TaskPool:
          self.worker_timeout = worker_timeout

          self._channel = Channel()
-         self._progress_subscribers: Dict[str, Callable[[float, str], Coroutine]] = {}
+         self._progress_subscribers: dict[str, Callable[[float, str], Coroutine]] = {}

      @property
      def running_tasks(self):

@@ -113,7 +113,7 @@ class TaskPool:
          raise RuntimeError('Pool already started')

      def submit(
-         self, task_uid: str, function_name: str, args: Union[tuple, None] = None, kwargs: Union[dict, None] = None
+         self, task_uid: str, function_name: str, args: tuple | None = None, kwargs: dict | None = None
      ) -> TaskDefinition:
          """
          Submit a new task to the pool

@@ -196,7 +196,7 @@ class TaskPool:
          await self.loop_stopped.wait()
          await self._terminate_workers()

-     async def join(self, timeout: Optional[float] = None):
+     async def join(self, timeout: float | None = None):
          """
          Join the pool and wait for workers to complete

@@ -473,7 +473,7 @@ class TaskPool:
          elif is_progress(worker_msg) and worker_msg['task_uid'] in self._progress_subscribers:
              await self._progress_subscribers[worker_msg['task_uid']](worker_msg['progress'], worker_msg['message'])

-     async def _wait_queue_depletion(self, timeout: Optional[float] = None):
+     async def _wait_queue_depletion(self, timeout: float | None = None):
          """
          Wait until all tasks have been marked as completed
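
`_progress_subscribers` above keys per-task coroutine callbacks by task uid and awaits them as progress messages arrive from workers. A stripped-down sketch of that dispatch pattern outside of the pool machinery; the names and the simulated progress event here are illustrative:

```python
import asyncio
from collections.abc import Callable, Coroutine

# Per-task progress callbacks keyed by task uid, same shape as the hunk above
_progress_subscribers: dict[str, Callable[[float, str], Coroutine]] = {}


async def on_progress(progress: float, message: str) -> None:
    print(f'task progress {progress:.0f}%: {message}')


async def main() -> None:
    _progress_subscribers['task-1'] = on_progress
    # Simulate a worker reporting progress for task-1
    if 'task-1' in _progress_subscribers:
        await _progress_subscribers['task-1'](42.0, 'crunching numbers')


asyncio.run(main())
```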
 

dara/core/internal/pool/utils.py +4 -3

@@ -20,9 +20,10 @@ import os
  import pickle
  import signal
  import sys
+ from collections.abc import Callable
  from multiprocessing.process import BaseProcess
  from multiprocessing.shared_memory import SharedMemory
- from typing import Any, Callable, Optional, Tuple
+ from typing import Any

  import anyio
  from tblib import Traceback

@@ -48,7 +49,7 @@ class SubprocessException:
          return self.exception.with_traceback(tb)


- SharedMemoryPointer = Tuple[str, int]
+ SharedMemoryPointer = tuple[str, int]


  class PicklingException(Exception):

@@ -106,7 +107,7 @@ def read_from_shared_memory(pointer: SharedMemoryPointer) -> Any:
          raise PicklingException(*e.args) from e


- async def wait_while(condition: Callable[[], bool], timeout: Optional[float] = None):
+ async def wait_while(condition: Callable[[], bool], timeout: float | None = None):
      """
      Util to wait until a condition is False or timeout is exceeded
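
`wait_while(condition, timeout)` above waits until a condition turns False or a timeout expires. The dara.core body is not shown in this diff; a plausible sketch with anyio primitives (the polling interval is an arbitrary choice here):

```python
from collections.abc import Callable

import anyio


async def wait_while(condition: Callable[[], bool], timeout: float | None = None) -> None:
    # Poll the condition until it becomes False, optionally bounded by a timeout
    async def _poll() -> None:
        while condition():
            await anyio.sleep(0.05)

    if timeout is None:
        await _poll()
    else:
        with anyio.fail_after(timeout):  # raises TimeoutError when exceeded
            await _poll()


async def main() -> None:
    counter = {'n': 3}

    def still_busy() -> bool:
        counter['n'] -= 1
        return counter['n'] > 0

    await wait_while(still_busy, timeout=1.0)
    print('condition cleared')


anyio.run(main)
```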
 

dara/core/internal/pool/worker.py +3 -3

@@ -21,13 +21,13 @@ import logging
  import os
  import signal
  import sys
+ from collections.abc import Callable
  from datetime import datetime
  from importlib import import_module
  from inspect import iscoroutinefunction
  from multiprocessing import get_context
  from multiprocessing.context import SpawnProcess
  from time import sleep
- from typing import Callable, Optional

  import anyio

@@ -195,7 +195,7 @@ class WorkerProcess:

      status: WorkerStatus

-     task_uid: Optional[str] = None
+     task_uid: str | None = None
      """Current task UID being processed by the worker"""

      channel: Channel

@@ -213,7 +213,7 @@ class WorkerProcess:
          self.process = ctx.Process(target=worker_loop, args=(worker_params, channel), name=WORKER_NAME)
          self.process.start()

-     def update_status(self, worker_status: WorkerStatus, task_uid: Optional[str] = None):
+     def update_status(self, worker_status: WorkerStatus, task_uid: str | None = None):
          self.status = worker_status
          self.updated_at = datetime.now()
          self.task_uid = task_uid

dara/core/internal/registries.py +4 -4

@@ -15,9 +15,9 @@ See the License for the specific language governing permissions and
  limitations under the License.
  """

- from collections.abc import Mapping
+ from collections.abc import Callable, Mapping
  from datetime import datetime
- from typing import Any, Callable, Set
+ from typing import Any

  from dara.core.auth import BaseAuthConfig
  from dara.core.base_definitions import ActionDef, ActionResolverDef, UploadResolverDef

@@ -55,10 +55,10 @@ auth_registry = Registry[BaseAuthConfig](RegistryType.AUTH_CONFIG)
  utils_registry = Registry[Any](RegistryType.UTILS, INITIAL_CORE_INTERNALS)
  static_kwargs_registry = Registry[Mapping[str, Any]](RegistryType.STATIC_KWARGS)

- websocket_registry = Registry[Set[str]](RegistryType.WEBSOCKET_CHANNELS)
+ websocket_registry = Registry[set[str]](RegistryType.WEBSOCKET_CHANNELS)
  """maps session_id -> WS channel"""

- sessions_registry = Registry[Set[str]](RegistryType.USER_SESSION)
+ sessions_registry = Registry[set[str]](RegistryType.USER_SESSION)
  """maps user_identifier -> session_ids """

  pending_tokens_registry = Registry[datetime](RegistryType.PENDING_TOKENS)
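
The registries above are typed containers, e.g. `Registry[set[str]]` for mapping sessions to websocket channels. dara.core's `Registry` takes a `RegistryType` and supports more than is visible in these hunks; the following is only a minimal generic sketch of the usage shape:

```python
from typing import Generic, TypeVar

T = TypeVar('T')


class Registry(Generic[T]):
    """Minimal stand-in: a named, typed key-value store."""

    def __init__(self, name: str, initial: dict[str, T] | None = None):
        self._name = name
        self._items: dict[str, T] = dict(initial or {})

    def set(self, key: str, value: T) -> None:
        self._items[key] = value

    def get(self, key: str) -> T:
        return self._items[key]

    def has(self, key: str) -> bool:
        return key in self._items


# Mirrors the websocket_registry declaration above: session_id -> channel set
websocket_registry = Registry[set[str]]('websocket_channels')
websocket_registry.set('session-1', {'channel-a'})
print(websocket_registry.get('session-1'))
```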