dara-core 1.17.6__py3-none-any.whl → 1.18.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dara/core/__init__.py +2 -0
- dara/core/actions.py +1 -2
- dara/core/auth/basic.py +9 -9
- dara/core/auth/routes.py +5 -5
- dara/core/auth/utils.py +4 -4
- dara/core/base_definitions.py +15 -22
- dara/core/cli.py +8 -7
- dara/core/configuration.py +5 -2
- dara/core/css.py +1 -2
- dara/core/data_utils.py +2 -2
- dara/core/defaults.py +4 -7
- dara/core/definitions.py +6 -9
- dara/core/http.py +7 -3
- dara/core/interactivity/actions.py +28 -30
- dara/core/interactivity/any_data_variable.py +6 -5
- dara/core/interactivity/any_variable.py +4 -7
- dara/core/interactivity/data_variable.py +1 -1
- dara/core/interactivity/derived_data_variable.py +7 -6
- dara/core/interactivity/derived_variable.py +93 -33
- dara/core/interactivity/filtering.py +19 -27
- dara/core/interactivity/plain_variable.py +3 -2
- dara/core/interactivity/switch_variable.py +4 -4
- dara/core/internal/cache_store/base_impl.py +2 -1
- dara/core/internal/cache_store/cache_store.py +17 -5
- dara/core/internal/cache_store/keep_all.py +4 -1
- dara/core/internal/cache_store/lru.py +5 -1
- dara/core/internal/cache_store/ttl.py +4 -1
- dara/core/internal/cgroup.py +1 -1
- dara/core/internal/dependency_resolution.py +46 -10
- dara/core/internal/devtools.py +2 -2
- dara/core/internal/download.py +4 -3
- dara/core/internal/encoder_registry.py +7 -7
- dara/core/internal/execute_action.py +4 -10
- dara/core/internal/hashing.py +1 -3
- dara/core/internal/import_discovery.py +3 -4
- dara/core/internal/normalization.py +9 -13
- dara/core/internal/pandas_utils.py +3 -3
- dara/core/internal/pool/task_pool.py +16 -10
- dara/core/internal/pool/utils.py +5 -7
- dara/core/internal/pool/worker.py +3 -2
- dara/core/internal/port_utils.py +1 -1
- dara/core/internal/registries.py +9 -4
- dara/core/internal/registry.py +3 -1
- dara/core/internal/registry_lookup.py +7 -3
- dara/core/internal/routing.py +77 -44
- dara/core/internal/scheduler.py +13 -8
- dara/core/internal/settings.py +2 -2
- dara/core/internal/tasks.py +8 -14
- dara/core/internal/utils.py +11 -10
- dara/core/internal/websocket.py +18 -19
- dara/core/js_tooling/js_utils.py +23 -24
- dara/core/logging.py +3 -6
- dara/core/main.py +14 -11
- dara/core/metrics/cache.py +1 -1
- dara/core/metrics/utils.py +3 -3
- dara/core/persistence.py +1 -1
- dara/core/umd/dara.core.umd.js +146 -128
- dara/core/visual/components/__init__.py +2 -2
- dara/core/visual/components/fallback.py +3 -3
- dara/core/visual/css/__init__.py +30 -31
- dara/core/visual/dynamic_component.py +10 -11
- dara/core/visual/progress_updater.py +4 -3
- {dara_core-1.17.6.dist-info → dara_core-1.18.0.dist-info}/METADATA +10 -10
- dara_core-1.18.0.dist-info/RECORD +114 -0
- dara_core-1.17.6.dist-info/RECORD +0 -114
- {dara_core-1.17.6.dist-info → dara_core-1.18.0.dist-info}/LICENSE +0 -0
- {dara_core-1.17.6.dist-info → dara_core-1.18.0.dist-info}/WHEEL +0 -0
- {dara_core-1.17.6.dist-info → dara_core-1.18.0.dist-info}/entry_points.txt +0 -0
```diff
@@ -15,7 +15,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-from
+from collections.abc import Coroutine
+from typing import Callable, Dict, Literal, Union
 
 from dara.core.internal.registry import Registry, RegistryType
 from dara.core.internal.utils import async_dedupe
@@ -28,6 +29,7 @@ RegistryLookupKey = Literal[
     RegistryType.STATIC_KWARGS,
     RegistryType.UPLOAD_RESOLVER,
     RegistryType.BACKEND_STORE,
+    RegistryType.DOWNLOAD_CODE,
 ]
 CustomRegistryLookup = Dict[RegistryLookupKey, Callable[[str], Coroutine]]
 
@@ -37,7 +39,9 @@ class RegistryLookup:
     Manages registry Lookup.
     """
 
-    def __init__(self, handlers: CustomRegistryLookup = {}):
+    def __init__(self, handlers: Union[CustomRegistryLookup, None] = None):
+        if handlers is None:
+            handlers = {}
         self.handlers = handlers
 
     @async_dedupe
@@ -62,4 +66,4 @@ class RegistryLookup:
                return entry
            raise ValueError(
                f'Could not find uid {uid} in {registry.name} registry, did you register it before the app was initialized?'
-            )
+            ) from e
```
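
The `__init__` change above replaces a mutable `dict` default with a `None` sentinel, since a literal default is created once at definition time and shared by every caller. A minimal sketch of the pattern (class and field names here are illustrative, not Dara's):

```python
from typing import Dict, Optional


class LookupExample:
    """Shows why a shared mutable default is swapped for a None sentinel."""

    def __init__(self, handlers: Optional[Dict[str, str]] = None):
        # With a literal `= {}` default the same dict object would be reused
        # across instances; building it in the body gives each instance its own.
        self.handlers = handlers if handlers is not None else {}


a = LookupExample()
b = LookupExample()
a.handlers['x'] = 'registered'
assert b.handlers == {}  # unaffected, because the dicts are separate objects
```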
dara/core/internal/routing.py CHANGED

```diff
@@ -18,9 +18,10 @@ limitations under the License.
 import inspect
 import json
 import os
+from collections.abc import Mapping
 from functools import wraps
 from importlib.metadata import version
-from typing import Any, Callable, Dict, List,
+from typing import Any, Callable, Dict, List, Optional
 
 import anyio
 import pandas
@@ -57,6 +58,7 @@ from dara.core.internal.registries import (
     component_registry,
     data_variable_registry,
     derived_variable_registry,
+    download_code_registry,
     latest_value_registry,
     static_kwargs_registry,
     template_registry,
@@ -86,7 +88,7 @@ def error_decorator(handler: Callable[..., Any]):
                 if isinstance(err, HTTPException):
                     raise err
                 dev_logger.error('Unhandled error', error=err)
-                raise HTTPException(status_code=500, detail=str(err))
+                raise HTTPException(status_code=500, detail=str(err)) from err
 
         return _async_inner_func
 
@@ -99,7 +101,7 @@ def error_decorator(handler: Callable[..., Any]):
             if isinstance(err, HTTPException):
                 raise err
             dev_logger.error('Unhandled error', error=err)
-            raise HTTPException(status_code=500, detail=str(err))
+            raise HTTPException(status_code=500, detail=str(err)) from err
 
     return _inner_func
 
@@ -113,7 +115,7 @@ def create_router(config: Configuration):
    core_api_router = APIRouter()
 
    @core_api_router.get('/actions', dependencies=[Depends(verify_session)])
-    async def get_actions():
+    async def get_actions():
        return action_def_registry.get_all().items()
 
    class ActionRequestBody(BaseModel):
@@ -133,7 +135,7 @@ def create_router(config: Configuration):
        """Execution id, unique to this request"""
 
    @core_api_router.post('/action/{uid}', dependencies=[Depends(verify_session)])
-    async def get_action(uid: str, body: ActionRequestBody):
+    async def get_action(uid: str, body: ActionRequestBody):
        store: CacheStore = utils_registry.get('Store')
        task_mgr: TaskManager = utils_registry.get('TaskManager')
        registry_mgr: RegistryLookup = utils_registry.get('RegistryLookup')
@@ -150,7 +152,14 @@ def create_router(config: Configuration):
 
        # Execute the action - kick off a background task to run the handler
        response = await action_def.execute_action(
-            action_def,
+            action_def,
+            body.input,
+            values,
+            static_kwargs,
+            body.execution_id,
+            body.ws_channel,
+            store,
+            task_mgr,
        )
 
        if isinstance(response, BaseTask):
@@ -159,15 +168,22 @@ def create_router(config: Configuration):
 
        return {'execution_id': response}
 
-    @core_api_router.get('/download')
-    async def get_download(code: str):
+    @core_api_router.get('/download')  # explicitly unauthenticated
+    async def get_download(code: str):
        store: CacheStore = utils_registry.get('Store')
 
        try:
            data_entry = await store.get(DownloadRegistryEntry, key=code)
 
+            # If not found directly in the store, use the override registry
+            # to check if we can get the download entry from there
            if data_entry is None:
-
+                registry_mgr: RegistryLookup = utils_registry.get('RegistryLookup')
+                # NOTE: This will throw a Value/KeyError if the code is not found so no need to rethrow
+                data_entry = await registry_mgr.get(download_code_registry, code)
+                # We managed to find one from the lookup,
+                # remove it from the registry immediately because it's one time use
+                download_code_registry.remove(code)
 
            async_file, cleanup = await data_entry.download(data_entry)
 
@@ -188,11 +204,11 @@ def create_router(config: Configuration):
                background=BackgroundTask(cleanup),
            )
 
-        except KeyError:
-            raise ValueError('Invalid or expired download code')
+        except (KeyError, ValueError) as e:
+            raise ValueError('Invalid or expired download code') from e
 
    @core_api_router.get('/config', dependencies=[Depends(verify_session)])
-    async def get_config():
+    async def get_config():
        return {
            **config.model_dump(
                include={
@@ -209,13 +225,13 @@ def create_router(config: Configuration):
        }
 
    @core_api_router.get('/auth-config')
-    async def get_auth_config():
+    async def get_auth_config():
        return {
            'auth_components': config.auth_config.component_config.model_dump(),
        }
 
    @core_api_router.get('/components', dependencies=[Depends(verify_session)])
-    async def get_components(name: Optional[str] = None):
+    async def get_components(name: Optional[str] = None):
        """
        If name is passed, will try to register the component
 
@@ -236,7 +252,7 @@ def create_router(config: Configuration):
        ws_channel: str
 
    @core_api_router.post('/components/{component}', dependencies=[Depends(verify_session)])
-    async def get_component(component: str, body: ComponentRequestBody):
+    async def get_component(component: str, body: ComponentRequestBody):
        CURRENT_COMPONENT_ID.set(body.uid)
        WS_CHANNEL.set(body.ws_channel)
        store: CacheStore = utils_registry.get('Store')
@@ -265,7 +281,7 @@ def create_router(config: Configuration):
        raise HTTPException(status_code=400, detail='Requesting this type of component is not supported')
 
    @core_api_router.get('/derived-variable/{uid}/latest', dependencies=[Depends(verify_session)])
-    async def get_latest_derived_variable(uid: str):
+    async def get_latest_derived_variable(uid: str):
        try:
            store: CacheStore = utils_registry.get('Store')
            latest_value_entry = latest_value_registry.get(uid)
@@ -289,9 +305,7 @@ def create_router(config: Configuration):
            return latest_value
 
        except KeyError as err:
-            raise ValueError(f'Could not find latest value for derived variable with uid: {uid}')
-                err.__traceback__
-            )
+            raise ValueError(f'Could not find latest value for derived variable with uid: {uid}') from err
 
    class DataVariableRequestBody(BaseModel):
        filters: Optional[FilterQuery] = None
@@ -306,7 +320,7 @@ def create_router(config: Configuration):
        limit: Optional[int] = None,
        order_by: Optional[str] = None,
        index: Optional[str] = None,
-    ):
+    ):
        try:
            store: CacheStore = utils_registry.get('Store')
            task_mgr: TaskManager = utils_registry.get('TaskManager')
@@ -318,11 +332,15 @@ def create_router(config: Configuration):
 
            if data_variable_entry.type == 'derived':
                if body.cache_key is None:
-                    raise HTTPException(
+                    raise HTTPException(
+                        status_code=400,
+                        detail='Cache key is required for derived data variables',
+                    )
 
                if body.ws_channel is None:
                    raise HTTPException(
-                        status_code=400,
+                        status_code=400,
+                        detail='Websocket channel is required for derived data variables',
                    )
 
                derived_variable_entry = await registry_mgr.get(derived_variable_registry, uid)
@@ -351,7 +369,10 @@ def create_router(config: Configuration):
            dev_logger.debug(
                f'DataVariable {data_variable_entry.uid[:3]}..{data_variable_entry.uid[-3:]}',
                'return value',
-                {
+                {
+                    'value': data.describe() if isinstance(data, pandas.DataFrame) else None,
+                    'uid': uid,
+                },  # type: ignore
            )
 
            if data is None:
@@ -360,10 +381,11 @@ def create_router(config: Configuration):
            # Explicitly convert to JSON to avoid implicit serialization;
            # return as records as that makes more sense in a JSON structure
            return Response(
-                content=df_to_json(data) if isinstance(data, pandas.DataFrame) else data,
-
+                content=df_to_json(data) if isinstance(data, pandas.DataFrame) else data,
+                media_type='application/json',
+            )  # type: ignore
        except ValueError as e:
-            raise HTTPException(status_code=400, detail=str(e))
+            raise HTTPException(status_code=400, detail=str(e)) from e
 
    class DataVariableCountRequestBody(BaseModel):
        cache_key: Optional[str] = None
@@ -383,12 +405,13 @@ def create_router(config: Configuration):
 
            if body is None or body.cache_key is None:
                raise HTTPException(
-                    status_code=400,
+                    status_code=400,
+                    detail="Cache key is required when requesting DerivedDataVariable's count",
                )
 
            return await variable_def.get_total_count(variable_def, store, body.cache_key, body.filters)
        except ValueError as e:
-            raise HTTPException(status_code=400, detail=str(e))
+            raise HTTPException(status_code=400, detail=str(e)) from e
 
    @core_api_router.get('/data-variable/{uid}/schema', dependencies=[Depends(verify_session)])
    async def get_data_variable_schema(uid: str, cache_key: Optional[str] = None):
@@ -402,7 +425,8 @@ def create_router(config: Configuration):
 
            if cache_key is None:
                raise HTTPException(
-                    status_code=400,
+                    status_code=400,
+                    detail='Cache key is required when requesting DerivedDataVariable schema',
                )
 
            # Use the other registry for derived variables
@@ -411,7 +435,7 @@ def create_router(config: Configuration):
            content = json.dumps(jsonable_encoder(data)) if isinstance(data, dict) else data
            return Response(content=content, media_type='application/json')
        except ValueError as e:
-            raise HTTPException(status_code=400, detail=str(e))
+            raise HTTPException(status_code=400, detail=str(e)) from e
 
    @core_api_router.post('/data/upload', dependencies=[Depends(verify_session)])
    async def upload_data(
@@ -427,7 +451,10 @@ def create_router(config: Configuration):
        registry_mgr: RegistryLookup = utils_registry.get('RegistryLookup')
 
        if data_uid is None and resolver_id is None:
-            raise HTTPException(
+            raise HTTPException(
+                400,
+                'Neither resolver_id or data_uid specified, at least one of them is required',
+            )
 
        try:
            # If resolver id is provided, run the custom
@@ -440,16 +467,16 @@ def create_router(config: Configuration):
 
            return {'status': 'SUCCESS'}
        except Exception as e:
-            raise HTTPException(status_code=400, detail=str(e))
+            raise HTTPException(status_code=400, detail=str(e)) from e
 
    class DerivedStateRequestBody(BaseModel):
        values: NormalizedPayload[List[Any]]
-
+        force_key: Optional[str] = None
        ws_channel: str
        is_data_variable: Optional[bool] = False
 
    @core_api_router.post('/derived-variable/{uid}', dependencies=[Depends(verify_session)])
-    async def get_derived_variable(uid: str, body: DerivedStateRequestBody):
+    async def get_derived_variable(uid: str, body: DerivedStateRequestBody):
        task_mgr: TaskManager = utils_registry.get('TaskManager')
        store: CacheStore = utils_registry.get('Store')
        registry_mgr: RegistryLookup = utils_registry.get('RegistryLookup')
@@ -457,7 +484,7 @@ def create_router(config: Configuration):
 
        values = denormalize(body.values.data, body.values.lookup)
 
-        result = await variable_def.get_value(variable_def, store, task_mgr, values, body.
+        result = await variable_def.get_value(variable_def, store, task_mgr, values, body.force_key)
 
        response: Any = result
 
@@ -466,7 +493,10 @@ def create_router(config: Configuration):
        if isinstance(result['value'], BaseTask):
            # Kick off the task
            await task_mgr.run_task(result['value'], body.ws_channel)
-            response = {
+            response = {
+                'task_id': result['value'].task_id,
+                'cache_key': result['cache_key'],
+            }
 
        dev_logger.debug(
            f'DerivedVariable {variable_def.uid[:3]}..{variable_def.uid[-3:]}',
@@ -512,7 +542,7 @@ def create_router(config: Configuration):
                tg.start_soon(_write, store_uid, value)
 
    @core_api_router.get('/tasks/{task_id}', dependencies=[Depends(verify_session)])
-    async def get_task_result(task_id: str):
+    async def get_task_result(task_id: str):
        try:
            task_mgr: TaskManager = utils_registry.get('TaskManager')
            res = await task_mgr.get_result(task_id)
@@ -528,25 +558,28 @@ def create_router(config: Configuration):
                return Response(df_to_json(res))
 
            return res
-        except Exception as
-            raise ValueError(f'The result for task id {task_id} could not be found')
+        except Exception as err:
+            raise ValueError(f'The result for task id {task_id} could not be found') from err
 
    @core_api_router.delete('/tasks/{task_id}', dependencies=[Depends(verify_session)])
-    async def cancel_task(task_id: str):
+    async def cancel_task(task_id: str):
        try:
            task_mgr: TaskManager = utils_registry.get('TaskManager')
            return await task_mgr.cancel_task(task_id)
        except TaskManagerError as e:
-            dev_logger.error(
+            dev_logger.error(
+                f'The task id {task_id} could not be found, it may have already been cancelled',
+                e,
+            )
 
    @core_api_router.get('/template/{template}', dependencies=[Depends(verify_session)])
-    async def get_template(template: str):
+    async def get_template(template: str):
        try:
            selected_template = template_registry.get(template)
            normalized_template, lookup = normalize(jsonable_encoder(selected_template))
            return {'data': normalized_template, 'lookup': lookup}
-        except KeyError:
-            raise HTTPException(status_code=404, detail=f'Template: {template}, not found in registry')
+        except KeyError as err:
+            raise HTTPException(status_code=404, detail=f'Template: {template}, not found in registry') from err
        except Exception as e:
            dev_logger.error('Something went wrong while trying to get the template', e)
 
```
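
Most of the routing changes swap bare re-raises for explicit exception chaining. A small standalone sketch of what `raise ... from err` does, independent of FastAPI or Dara:

```python
def parse_port(raw: str) -> int:
    try:
        return int(raw)
    except ValueError as err:
        # `from err` stores the original exception on __cause__, so the first
        # traceback is kept and reported as the direct cause of the new error.
        raise RuntimeError(f'Invalid port value: {raw!r}') from err


try:
    parse_port('not-a-port')
except RuntimeError as exc:
    assert isinstance(exc.__cause__, ValueError)
```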
dara/core/internal/scheduler.py CHANGED

```diff
@@ -20,7 +20,7 @@ from datetime import datetime
 from multiprocessing import get_context
 from multiprocessing.process import BaseProcess
 from pickle import PicklingError
-from typing import Any, List, Optional, Union
+from typing import Any, List, Optional, Union, cast
 
 from croniter import croniter
 from pydantic import BaseModel, field_validator
@@ -56,25 +56,28 @@ class ScheduledJob(BaseModel):
            job_process = ctx.Process(target=self._refresh_timer, args=(func, args), daemon=True)
            job_process.start()
            return job_process
-        except PicklingError:
+        except PicklingError as err:
            raise PicklingError(
                """
                Unable to pickle scheduled function. Please ensure that the function you are trying
                to schedule is not in the same file as the ConfigurationBuilder is defined and that
                the function is not a lambda.
                """
-            )
+            ) from err
 
    def _refresh_timer(self, func, args):
        while self.continue_running and not (self.run_once and not self.first_execution):
-            interval
+            interval: int
            # If there's more than one interval to wait, i.e. this is a weekday process
-            if
+            if isinstance(self.interval, list):
                # Wait the first interval if this is the first execution of the job
                interval = self.interval[0] if self.first_execution else self.interval[1]
+            else:
+                interval = self.interval
+
            self.first_execution = False
            # Wait the interval and then run the job
-            time.sleep(interval)
+            time.sleep(cast(int, interval))
            func(*args)
 
 
@@ -176,7 +179,7 @@ class ScheduledJobFactory(BaseModel):
 
    @field_validator('weekday', mode='before')
    @classmethod
-    def validate_weekday(cls, weekday: Any) -> datetime:
+    def validate_weekday(cls, weekday: Any) -> datetime:
        if isinstance(weekday, datetime):
            return weekday
        if isinstance(weekday, str):
@@ -283,7 +286,9 @@ class Scheduler:
    def _weekday(self, weekday: int):
        # The job must execute on a weekly interval
        return ScheduledJobFactory(
-            interval=self.interval * 604800,
+            interval=self.interval * 604800,
+            run_once=self._run_once,
+            weekday=str(weekday),  # type: ignore
        )
 
    def monday(self) -> ScheduledJobFactory:
```
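
The `_refresh_timer` change narrows the interval, which may be a single int or a list, with an `isinstance` check before sleeping. A standalone sketch of the same narrowing idea (function and argument names are illustrative, not the scheduler's API):

```python
from typing import List, Union, cast


def next_wait(interval: Union[int, List[int]], first_run: bool) -> int:
    """Pick a sleep interval from either a single int or a [first, repeat] pair."""
    if isinstance(interval, list):
        # A two-element list means: wait interval[0] before the first run,
        # then interval[1] between subsequent runs.
        return interval[0] if first_run else interval[1]
    # cast() only informs the type checker; at runtime the value is returned unchanged.
    return cast(int, interval)


assert next_wait([5, 60], first_run=True) == 5
assert next_wait(30, first_run=False) == 30
```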
dara/core/internal/settings.py CHANGED

```diff
@@ -74,7 +74,7 @@ def generate_env_file(filename='.env'):
        f.write(env_content)
 
 
-@lru_cache
+@lru_cache
 def get_settings():
    """
    Get a cached instance of the settings, loading values from the .env if present.
@@ -83,7 +83,7 @@ def get_settings():
 
    # Test purposes - if DARA_TEST_FLAG is set then override env with .env.test
    if os.environ.get('DARA_TEST_FLAG', None) is not None:
-        return Settings(**dotenv_values('.env.test'))
+        return Settings(**dotenv_values('.env.test'))  # type: ignore
 
    env_error = False
 
```
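
`get_settings` stays wrapped in `functools.lru_cache`, so the environment is only parsed on the first call. A tiny sketch of that caching behaviour with a made-up loader (the `load_settings` name and return value are illustrative):

```python
from functools import lru_cache


@lru_cache
def load_settings() -> dict:
    # Pretend this parses a .env file; with lru_cache the work runs once and
    # every later call returns the cached result.
    print('loading settings...')
    return {'debug': False}


first = load_settings()
second = load_settings()  # nothing is printed the second time
assert first is second    # the very same object comes back from the cache
```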
dara/core/internal/tasks.py CHANGED

```diff
@@ -15,9 +15,11 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
+import contextlib
 import inspect
 import math
-from
+from collections.abc import Awaitable
+from typing import Any, Callable, Dict, List, Optional, Union, overload
 
 from anyio import (
     CancelScope,
@@ -137,14 +139,12 @@ class Task(BaseTask):
 
        async def on_progress(progress: float, msg: str):
            if send_stream is not None:
-                try:
+                with contextlib.suppress(ClosedResourceError):
                    await send_stream.send(TaskProgressUpdate(task_id=self.task_id, progress=progress, message=msg))
-                except ClosedResourceError:
-                    pass
 
        async def on_result(result: Any):
            if send_stream is not None:
-                try:
+                with contextlib.suppress(ClosedResourceError):
                    await send_stream.send(
                        TaskResult(
                            task_id=self.task_id,
@@ -153,19 +153,15 @@ class Task(BaseTask):
                            reg_entry=self.reg_entry,
                        )
                    )
-                except ClosedResourceError:
-                    pass
 
        async def on_error(exc: BaseException):
            if send_stream is not None:
-                try:
+                with contextlib.suppress(ClosedResourceError):
                    await send_stream.send(
                        TaskError(
                            task_id=self.task_id, error=exc, cache_key=self.cache_key, reg_entry=self.reg_entry
                        )
                    )
-                except ClosedResourceError:
-                    pass
 
        with pool.on_progress(self.task_id, on_progress):
            pool_task_def = pool.submit(self.task_id, self._func_name, args=tuple(self._args), kwargs=self._kwargs)
@@ -362,12 +358,10 @@ class TaskManager:
        self.store = store
 
    @overload
-    async def run_task(self, task: PendingTask, ws_channel: Optional[str] = None) -> Any:
-        ...
+    async def run_task(self, task: PendingTask, ws_channel: Optional[str] = None) -> Any: ...
 
    @overload
-    async def run_task(self, task: BaseTask, ws_channel: Optional[str] = None) -> PendingTask:
-        ...
+    async def run_task(self, task: BaseTask, ws_channel: Optional[str] = None) -> PendingTask: ...
 
    async def run_task(self, task: BaseTask, ws_channel: Optional[str] = None):
        """
```
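
The `try/except ClosedResourceError: pass` blocks above collapse into `contextlib.suppress`, which states the same "ignore this specific error" intent in one line. A minimal equivalence sketch:

```python
import contextlib

items: list = []

# Classic form: swallow the error explicitly.
try:
    items.remove('missing')
except ValueError:
    pass

# Equivalent form: the context manager suppresses only the listed exception type.
with contextlib.suppress(ValueError):
    items.remove('missing')
```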
dara/core/internal/utils.py CHANGED

```diff
@@ -20,6 +20,7 @@ from __future__ import annotations
 import asyncio
 import inspect
 import os
+from collections.abc import Awaitable, Coroutine, Sequence
 from functools import wraps
 from importlib import import_module
 from importlib.util import find_spec
@@ -27,13 +28,10 @@ from types import ModuleType
 from typing import (
     TYPE_CHECKING,
     Any,
-    Awaitable,
     Callable,
-    Coroutine,
     Dict,
     Literal,
     Optional,
-    Sequence,
     Tuple,
     Union,
 )
@@ -79,7 +77,7 @@ def get_cache_scope(cache_type: Optional[CacheType]) -> CacheScope:
    return 'global'
 
 
-async def run_user_handler(handler: Callable, args: Sequence = [], kwargs: dict = {}):
+async def run_user_handler(handler: Callable, args: Union[Sequence, None] = None, kwargs: Union[dict, None] = None):
    """
    Run a user-defined handler function. Runs sync functions in a threadpool.
    Handles SystemExits cleanly.
@@ -88,6 +86,10 @@ async def run_user_handler(handler: Callable, args: Sequence = [], kwargs: dict
    :param args: list of arguments to pass to the function
    :param kwargs: dict of kwargs to past to the function
    """
+    if args is None:
+        args = []
+    if kwargs is None:
+        kwargs = {}
    with handle_system_exit('User defined function quit unexpectedly'):
        if inspect.iscoroutinefunction(handler):
            return await handler(*args, **kwargs)
@@ -164,14 +166,14 @@ def enforce_sso(conf: ConfigurationBuilder):
    Raises if SSO is not used
    """
    try:
-        from dara.enterprise import SSOAuthConfig
+        from dara.enterprise import SSOAuthConfig  # pyright: ignore[reportMissingImports]
 
        if conf.auth_config is None or not isinstance(conf.auth_config, SSOAuthConfig):
            raise ValueError('Config does not have SSO auth enabled. Please update your application to configure SSO.')
-    except ImportError:
+    except ImportError as err:
        raise ValueError(
            'SSO is not enabled. Please install the dara_enterprise package and configure SSO to use this feature.'
-        )
+        ) from err
 
 
 def async_dedupe(fn: Callable[..., Awaitable]):
@@ -228,8 +230,7 @@ def resolve_exception_group(error: Any):
 
    :param error: The error to resolve
    """
-    if isinstance(error, ExceptionGroup):
-
-        return resolve_exception_group(error.exceptions[0])
+    if isinstance(error, ExceptionGroup) and len(error.exceptions) == 1:
+        return resolve_exception_group(error.exceptions[0])
 
    return error
```
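
`resolve_exception_group` now only unwraps groups that hold exactly one exception, so multi-error groups are no longer collapsed to their first member. A standalone sketch of the same idea using the built-in `ExceptionGroup` from Python 3.11+, not Dara's helper:

```python
def unwrap_single(error: BaseException) -> BaseException:
    """Recursively unwrap exception groups that contain exactly one exception."""
    if isinstance(error, ExceptionGroup) and len(error.exceptions) == 1:
        return unwrap_single(error.exceptions[0])
    # Groups with several exceptions are returned as-is so no error is hidden.
    return error


inner = ValueError('boom')
wrapped = ExceptionGroup('outer', [ExceptionGroup('inner', [inner])])
assert unwrap_single(wrapped) is inner
```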
dara/core/internal/websocket.py CHANGED

```diff
@@ -204,28 +204,27 @@ class WebSocketHandler:
        message_id = message.channel
 
        # If the message has a channel ID, it's a response to a previous message
-        if message_id:
-
-            event, existing_messages = self.pending_responses[message_id]
-
-            # If the response is chunked then collect the messages in pending responses
-            if message.chunk_count is not None:
-                if existing_messages is not None and isinstance(existing_messages, list):
-                    existing_messages.append(message.message)
-                else:
-                    existing_messages = [message.message]
-                self.pending_responses[message_id] = (
-                    event,
-                    existing_messages,
-                )
+        if message_id and message_id in self.pending_responses:
+            event, existing_messages = self.pending_responses[message_id]
 
-
-
-
+            # If the response is chunked then collect the messages in pending responses
+            if message.chunk_count is not None:
+                if existing_messages is not None and isinstance(existing_messages, list):
+                    existing_messages.append(message.message)
                else:
-
-
+                    existing_messages = [message.message]
+                self.pending_responses[message_id] = (
+                    event,
+                    existing_messages,
+                )
+
+                # If all chunks have been received, set the event to notify the waiting coroutine
+                if len(existing_messages) == message.chunk_count:
                    event.set()
+            else:
+                # Store the response and set the event to notify the waiting coroutine
+                self.pending_responses[message_id] = (event, message.message)
+                event.set()
 
        return None
 
```
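
The reworked handler only touches `pending_responses` for channel ids it is actually waiting on, accumulates chunked messages, and sets the waiting event once every chunk has arrived (or immediately for unchunked responses). A stripped-down sketch of that chunk-assembly idea with plain asyncio primitives; the names and structure are illustrative, not Dara's `WebSocketHandler` API:

```python
import asyncio
from typing import Dict, List, Tuple

# Maps a message/channel id to (completion event, chunks received so far).
pending: Dict[str, Tuple[asyncio.Event, List[str]]] = {}


def receive_chunk(message_id: str, chunk: str, chunk_count: int) -> None:
    if message_id not in pending:
        return  # nobody is waiting on this id, drop the message
    event, chunks = pending[message_id]
    chunks.append(chunk)
    if len(chunks) == chunk_count:
        event.set()  # wake the coroutine waiting for the full response


async def wait_for_response(message_id: str) -> str:
    event = asyncio.Event()
    pending[message_id] = (event, [])
    await event.wait()
    _, chunks = pending.pop(message_id)
    return ''.join(chunks)
```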