dara-core 1.17.6__py3-none-any.whl → 1.18.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. dara/core/__init__.py +2 -0
  2. dara/core/actions.py +1 -2
  3. dara/core/auth/basic.py +9 -9
  4. dara/core/auth/routes.py +5 -5
  5. dara/core/auth/utils.py +4 -4
  6. dara/core/base_definitions.py +15 -22
  7. dara/core/cli.py +8 -7
  8. dara/core/configuration.py +5 -2
  9. dara/core/css.py +1 -2
  10. dara/core/data_utils.py +2 -2
  11. dara/core/defaults.py +4 -7
  12. dara/core/definitions.py +6 -9
  13. dara/core/http.py +7 -3
  14. dara/core/interactivity/actions.py +28 -30
  15. dara/core/interactivity/any_data_variable.py +6 -5
  16. dara/core/interactivity/any_variable.py +4 -7
  17. dara/core/interactivity/data_variable.py +1 -1
  18. dara/core/interactivity/derived_data_variable.py +7 -6
  19. dara/core/interactivity/derived_variable.py +93 -33
  20. dara/core/interactivity/filtering.py +19 -27
  21. dara/core/interactivity/plain_variable.py +3 -2
  22. dara/core/interactivity/switch_variable.py +4 -4
  23. dara/core/internal/cache_store/base_impl.py +2 -1
  24. dara/core/internal/cache_store/cache_store.py +17 -5
  25. dara/core/internal/cache_store/keep_all.py +4 -1
  26. dara/core/internal/cache_store/lru.py +5 -1
  27. dara/core/internal/cache_store/ttl.py +4 -1
  28. dara/core/internal/cgroup.py +1 -1
  29. dara/core/internal/dependency_resolution.py +46 -10
  30. dara/core/internal/devtools.py +2 -2
  31. dara/core/internal/download.py +4 -3
  32. dara/core/internal/encoder_registry.py +7 -7
  33. dara/core/internal/execute_action.py +4 -10
  34. dara/core/internal/hashing.py +1 -3
  35. dara/core/internal/import_discovery.py +3 -4
  36. dara/core/internal/normalization.py +9 -13
  37. dara/core/internal/pandas_utils.py +3 -3
  38. dara/core/internal/pool/task_pool.py +16 -10
  39. dara/core/internal/pool/utils.py +5 -7
  40. dara/core/internal/pool/worker.py +3 -2
  41. dara/core/internal/port_utils.py +1 -1
  42. dara/core/internal/registries.py +9 -4
  43. dara/core/internal/registry.py +3 -1
  44. dara/core/internal/registry_lookup.py +7 -3
  45. dara/core/internal/routing.py +77 -44
  46. dara/core/internal/scheduler.py +13 -8
  47. dara/core/internal/settings.py +2 -2
  48. dara/core/internal/tasks.py +8 -14
  49. dara/core/internal/utils.py +11 -10
  50. dara/core/internal/websocket.py +18 -19
  51. dara/core/js_tooling/js_utils.py +23 -24
  52. dara/core/logging.py +3 -6
  53. dara/core/main.py +14 -11
  54. dara/core/metrics/cache.py +1 -1
  55. dara/core/metrics/utils.py +3 -3
  56. dara/core/persistence.py +1 -1
  57. dara/core/umd/dara.core.umd.js +146 -128
  58. dara/core/visual/components/__init__.py +2 -2
  59. dara/core/visual/components/fallback.py +3 -3
  60. dara/core/visual/css/__init__.py +30 -31
  61. dara/core/visual/dynamic_component.py +10 -11
  62. dara/core/visual/progress_updater.py +4 -3
  63. {dara_core-1.17.6.dist-info → dara_core-1.18.0.dist-info}/METADATA +10 -10
  64. dara_core-1.18.0.dist-info/RECORD +114 -0
  65. dara_core-1.17.6.dist-info/RECORD +0 -114
  66. {dara_core-1.17.6.dist-info → dara_core-1.18.0.dist-info}/LICENSE +0 -0
  67. {dara_core-1.17.6.dist-info → dara_core-1.18.0.dist-info}/WHEEL +0 -0
  68. {dara_core-1.17.6.dist-info → dara_core-1.18.0.dist-info}/entry_points.txt +0 -0
@@ -15,7 +15,8 @@ See the License for the specific language governing permissions and
15
15
  limitations under the License.
16
16
  """
17
17
 
18
- from typing import Callable, Coroutine, Dict, Literal
18
+ from collections.abc import Coroutine
19
+ from typing import Callable, Dict, Literal, Union
19
20
 
20
21
  from dara.core.internal.registry import Registry, RegistryType
21
22
  from dara.core.internal.utils import async_dedupe
@@ -28,6 +29,7 @@ RegistryLookupKey = Literal[
28
29
  RegistryType.STATIC_KWARGS,
29
30
  RegistryType.UPLOAD_RESOLVER,
30
31
  RegistryType.BACKEND_STORE,
32
+ RegistryType.DOWNLOAD_CODE,
31
33
  ]
32
34
  CustomRegistryLookup = Dict[RegistryLookupKey, Callable[[str], Coroutine]]
33
35
 
@@ -37,7 +39,9 @@ class RegistryLookup:
37
39
  Manages registry Lookup.
38
40
  """
39
41
 
40
- def __init__(self, handlers: CustomRegistryLookup = {}):
42
+ def __init__(self, handlers: Union[CustomRegistryLookup, None] = None):
43
+ if handlers is None:
44
+ handlers = {}
41
45
  self.handlers = handlers
42
46
 
43
47
  @async_dedupe
@@ -62,4 +66,4 @@ class RegistryLookup:
62
66
  return entry
63
67
  raise ValueError(
64
68
  f'Could not find uid {uid} in {registry.name} registry, did you register it before the app was initialized?'
65
- ).with_traceback(e.__traceback__)
69
+ ) from e
@@ -18,9 +18,10 @@ limitations under the License.
18
18
  import inspect
19
19
  import json
20
20
  import os
21
+ from collections.abc import Mapping
21
22
  from functools import wraps
22
23
  from importlib.metadata import version
23
- from typing import Any, Callable, Dict, List, Mapping, Optional
24
+ from typing import Any, Callable, Dict, List, Optional
24
25
 
25
26
  import anyio
26
27
  import pandas
@@ -57,6 +58,7 @@ from dara.core.internal.registries import (
57
58
  component_registry,
58
59
  data_variable_registry,
59
60
  derived_variable_registry,
61
+ download_code_registry,
60
62
  latest_value_registry,
61
63
  static_kwargs_registry,
62
64
  template_registry,
@@ -86,7 +88,7 @@ def error_decorator(handler: Callable[..., Any]):
86
88
  if isinstance(err, HTTPException):
87
89
  raise err
88
90
  dev_logger.error('Unhandled error', error=err)
89
- raise HTTPException(status_code=500, detail=str(err))
91
+ raise HTTPException(status_code=500, detail=str(err)) from err
90
92
 
91
93
  return _async_inner_func
92
94
 
@@ -99,7 +101,7 @@ def error_decorator(handler: Callable[..., Any]):
99
101
  if isinstance(err, HTTPException):
100
102
  raise err
101
103
  dev_logger.error('Unhandled error', error=err)
102
- raise HTTPException(status_code=500, detail=str(err))
104
+ raise HTTPException(status_code=500, detail=str(err)) from err
103
105
 
104
106
  return _inner_func
105
107
 
@@ -113,7 +115,7 @@ def create_router(config: Configuration):
113
115
  core_api_router = APIRouter()
114
116
 
115
117
  @core_api_router.get('/actions', dependencies=[Depends(verify_session)])
116
- async def get_actions(): # pylint: disable=unused-variable
118
+ async def get_actions():
117
119
  return action_def_registry.get_all().items()
118
120
 
119
121
  class ActionRequestBody(BaseModel):
@@ -133,7 +135,7 @@ def create_router(config: Configuration):
133
135
  """Execution id, unique to this request"""
134
136
 
135
137
  @core_api_router.post('/action/{uid}', dependencies=[Depends(verify_session)])
136
- async def get_action(uid: str, body: ActionRequestBody): # pylint: disable=unused-variable
138
+ async def get_action(uid: str, body: ActionRequestBody):
137
139
  store: CacheStore = utils_registry.get('Store')
138
140
  task_mgr: TaskManager = utils_registry.get('TaskManager')
139
141
  registry_mgr: RegistryLookup = utils_registry.get('RegistryLookup')
@@ -150,7 +152,14 @@ def create_router(config: Configuration):
150
152
 
151
153
  # Execute the action - kick off a background task to run the handler
152
154
  response = await action_def.execute_action(
153
- action_def, body.input, values, static_kwargs, body.execution_id, body.ws_channel, store, task_mgr
155
+ action_def,
156
+ body.input,
157
+ values,
158
+ static_kwargs,
159
+ body.execution_id,
160
+ body.ws_channel,
161
+ store,
162
+ task_mgr,
154
163
  )
155
164
 
156
165
  if isinstance(response, BaseTask):
@@ -159,15 +168,22 @@ def create_router(config: Configuration):
159
168
 
160
169
  return {'execution_id': response}
161
170
 
162
- @core_api_router.get('/download')
163
- async def get_download(code: str): # pylint: disable=unused-variable
171
+ @core_api_router.get('/download') # explicitly unauthenticated
172
+ async def get_download(code: str):
164
173
  store: CacheStore = utils_registry.get('Store')
165
174
 
166
175
  try:
167
176
  data_entry = await store.get(DownloadRegistryEntry, key=code)
168
177
 
178
+ # If not found directly in the store, use the override registry
179
+ # to check if we can get the download entry from there
169
180
  if data_entry is None:
170
- raise ValueError('Invalid or expired download code')
181
+ registry_mgr: RegistryLookup = utils_registry.get('RegistryLookup')
182
+ # NOTE: This will throw a Value/KeyError if the code is not found so no need to rethrow
183
+ data_entry = await registry_mgr.get(download_code_registry, code)
184
+ # We managed to find one from the lookup,
185
+ # remove it from the registry immediately because it's one time use
186
+ download_code_registry.remove(code)
171
187
 
172
188
  async_file, cleanup = await data_entry.download(data_entry)
173
189
 
@@ -188,11 +204,11 @@ def create_router(config: Configuration):
188
204
  background=BackgroundTask(cleanup),
189
205
  )
190
206
 
191
- except KeyError:
192
- raise ValueError('Invalid or expired download code')
207
+ except (KeyError, ValueError) as e:
208
+ raise ValueError('Invalid or expired download code') from e
193
209
 
194
210
  @core_api_router.get('/config', dependencies=[Depends(verify_session)])
195
- async def get_config(): # pylint: disable=unused-variable
211
+ async def get_config():
196
212
  return {
197
213
  **config.model_dump(
198
214
  include={
@@ -209,13 +225,13 @@ def create_router(config: Configuration):
209
225
  }
210
226
 
211
227
  @core_api_router.get('/auth-config')
212
- async def get_auth_config(): # pylint: disable=unused-variable
228
+ async def get_auth_config():
213
229
  return {
214
230
  'auth_components': config.auth_config.component_config.model_dump(),
215
231
  }
216
232
 
217
233
  @core_api_router.get('/components', dependencies=[Depends(verify_session)])
218
- async def get_components(name: Optional[str] = None): # pylint: disable=unused-variable
234
+ async def get_components(name: Optional[str] = None):
219
235
  """
220
236
  If name is passed, will try to register the component
221
237
 
@@ -236,7 +252,7 @@ def create_router(config: Configuration):
236
252
  ws_channel: str
237
253
 
238
254
  @core_api_router.post('/components/{component}', dependencies=[Depends(verify_session)])
239
- async def get_component(component: str, body: ComponentRequestBody): # pylint: disable=unused-variable
255
+ async def get_component(component: str, body: ComponentRequestBody):
240
256
  CURRENT_COMPONENT_ID.set(body.uid)
241
257
  WS_CHANNEL.set(body.ws_channel)
242
258
  store: CacheStore = utils_registry.get('Store')
@@ -265,7 +281,7 @@ def create_router(config: Configuration):
265
281
  raise HTTPException(status_code=400, detail='Requesting this type of component is not supported')
266
282
 
267
283
  @core_api_router.get('/derived-variable/{uid}/latest', dependencies=[Depends(verify_session)])
268
- async def get_latest_derived_variable(uid: str): # pylint: disable=unused-variable
284
+ async def get_latest_derived_variable(uid: str):
269
285
  try:
270
286
  store: CacheStore = utils_registry.get('Store')
271
287
  latest_value_entry = latest_value_registry.get(uid)
@@ -289,9 +305,7 @@ def create_router(config: Configuration):
289
305
  return latest_value
290
306
 
291
307
  except KeyError as err:
292
- raise ValueError(f'Could not find latest value for derived variable with uid: {uid}').with_traceback(
293
- err.__traceback__
294
- )
308
+ raise ValueError(f'Could not find latest value for derived variable with uid: {uid}') from err
295
309
 
296
310
  class DataVariableRequestBody(BaseModel):
297
311
  filters: Optional[FilterQuery] = None
@@ -306,7 +320,7 @@ def create_router(config: Configuration):
306
320
  limit: Optional[int] = None,
307
321
  order_by: Optional[str] = None,
308
322
  index: Optional[str] = None,
309
- ): # pylint: disable=unused-variable
323
+ ):
310
324
  try:
311
325
  store: CacheStore = utils_registry.get('Store')
312
326
  task_mgr: TaskManager = utils_registry.get('TaskManager')
@@ -318,11 +332,15 @@ def create_router(config: Configuration):
318
332
 
319
333
  if data_variable_entry.type == 'derived':
320
334
  if body.cache_key is None:
321
- raise HTTPException(status_code=400, detail='Cache key is required for derived data variables')
335
+ raise HTTPException(
336
+ status_code=400,
337
+ detail='Cache key is required for derived data variables',
338
+ )
322
339
 
323
340
  if body.ws_channel is None:
324
341
  raise HTTPException(
325
- status_code=400, detail='Websocket channel is required for derived data variables'
342
+ status_code=400,
343
+ detail='Websocket channel is required for derived data variables',
326
344
  )
327
345
 
328
346
  derived_variable_entry = await registry_mgr.get(derived_variable_registry, uid)
@@ -351,7 +369,10 @@ def create_router(config: Configuration):
351
369
  dev_logger.debug(
352
370
  f'DataVariable {data_variable_entry.uid[:3]}..{data_variable_entry.uid[-3:]}',
353
371
  'return value',
354
- {'value': data.describe() if isinstance(data, pandas.DataFrame) else None, 'uid': uid}, # type: ignore
372
+ {
373
+ 'value': data.describe() if isinstance(data, pandas.DataFrame) else None,
374
+ 'uid': uid,
375
+ }, # type: ignore
355
376
  )
356
377
 
357
378
  if data is None:
@@ -360,10 +381,11 @@ def create_router(config: Configuration):
360
381
  # Explicitly convert to JSON to avoid implicit serialization;
361
382
  # return as records as that makes more sense in a JSON structure
362
383
  return Response(
363
- content=df_to_json(data) if isinstance(data, pandas.DataFrame) else data, media_type='application/json'
364
- ) # type: ignore
384
+ content=df_to_json(data) if isinstance(data, pandas.DataFrame) else data,
385
+ media_type='application/json',
386
+ ) # type: ignore
365
387
  except ValueError as e:
366
- raise HTTPException(status_code=400, detail=str(e))
388
+ raise HTTPException(status_code=400, detail=str(e)) from e
367
389
 
368
390
  class DataVariableCountRequestBody(BaseModel):
369
391
  cache_key: Optional[str] = None
@@ -383,12 +405,13 @@ def create_router(config: Configuration):
383
405
 
384
406
  if body is None or body.cache_key is None:
385
407
  raise HTTPException(
386
- status_code=400, detail="Cache key is required when requesting DerivedDataVariable's count"
408
+ status_code=400,
409
+ detail="Cache key is required when requesting DerivedDataVariable's count",
387
410
  )
388
411
 
389
412
  return await variable_def.get_total_count(variable_def, store, body.cache_key, body.filters)
390
413
  except ValueError as e:
391
- raise HTTPException(status_code=400, detail=str(e))
414
+ raise HTTPException(status_code=400, detail=str(e)) from e
392
415
 
393
416
  @core_api_router.get('/data-variable/{uid}/schema', dependencies=[Depends(verify_session)])
394
417
  async def get_data_variable_schema(uid: str, cache_key: Optional[str] = None):
@@ -402,7 +425,8 @@ def create_router(config: Configuration):
402
425
 
403
426
  if cache_key is None:
404
427
  raise HTTPException(
405
- status_code=400, detail='Cache key is required when requesting DerivedDataVariable schema'
428
+ status_code=400,
429
+ detail='Cache key is required when requesting DerivedDataVariable schema',
406
430
  )
407
431
 
408
432
  # Use the other registry for derived variables
@@ -411,7 +435,7 @@ def create_router(config: Configuration):
411
435
  content = json.dumps(jsonable_encoder(data)) if isinstance(data, dict) else data
412
436
  return Response(content=content, media_type='application/json')
413
437
  except ValueError as e:
414
- raise HTTPException(status_code=400, detail=str(e))
438
+ raise HTTPException(status_code=400, detail=str(e)) from e
415
439
 
416
440
  @core_api_router.post('/data/upload', dependencies=[Depends(verify_session)])
417
441
  async def upload_data(
@@ -427,7 +451,10 @@ def create_router(config: Configuration):
427
451
  registry_mgr: RegistryLookup = utils_registry.get('RegistryLookup')
428
452
 
429
453
  if data_uid is None and resolver_id is None:
430
- raise HTTPException(400, 'Neither resolver_id or data_uid specified, at least one of them is required')
454
+ raise HTTPException(
455
+ 400,
456
+ 'Neither resolver_id or data_uid specified, at least one of them is required',
457
+ )
431
458
 
432
459
  try:
433
460
  # If resolver id is provided, run the custom
@@ -440,16 +467,16 @@ def create_router(config: Configuration):
440
467
 
441
468
  return {'status': 'SUCCESS'}
442
469
  except Exception as e:
443
- raise HTTPException(status_code=400, detail=str(e))
470
+ raise HTTPException(status_code=400, detail=str(e)) from e
444
471
 
445
472
  class DerivedStateRequestBody(BaseModel):
446
473
  values: NormalizedPayload[List[Any]]
447
- force: bool
474
+ force_key: Optional[str] = None
448
475
  ws_channel: str
449
476
  is_data_variable: Optional[bool] = False
450
477
 
451
478
  @core_api_router.post('/derived-variable/{uid}', dependencies=[Depends(verify_session)])
452
- async def get_derived_variable(uid: str, body: DerivedStateRequestBody): # pylint: disable=unused-variable
479
+ async def get_derived_variable(uid: str, body: DerivedStateRequestBody):
453
480
  task_mgr: TaskManager = utils_registry.get('TaskManager')
454
481
  store: CacheStore = utils_registry.get('Store')
455
482
  registry_mgr: RegistryLookup = utils_registry.get('RegistryLookup')
@@ -457,7 +484,7 @@ def create_router(config: Configuration):
457
484
 
458
485
  values = denormalize(body.values.data, body.values.lookup)
459
486
 
460
- result = await variable_def.get_value(variable_def, store, task_mgr, values, body.force)
487
+ result = await variable_def.get_value(variable_def, store, task_mgr, values, body.force_key)
461
488
 
462
489
  response: Any = result
463
490
 
@@ -466,7 +493,10 @@ def create_router(config: Configuration):
466
493
  if isinstance(result['value'], BaseTask):
467
494
  # Kick off the task
468
495
  await task_mgr.run_task(result['value'], body.ws_channel)
469
- response = {'task_id': result['value'].task_id, 'cache_key': result['cache_key']}
496
+ response = {
497
+ 'task_id': result['value'].task_id,
498
+ 'cache_key': result['cache_key'],
499
+ }
470
500
 
471
501
  dev_logger.debug(
472
502
  f'DerivedVariable {variable_def.uid[:3]}..{variable_def.uid[-3:]}',
@@ -512,7 +542,7 @@ def create_router(config: Configuration):
512
542
  tg.start_soon(_write, store_uid, value)
513
543
 
514
544
  @core_api_router.get('/tasks/{task_id}', dependencies=[Depends(verify_session)])
515
- async def get_task_result(task_id: str): # pylint: disable=unused-variable
545
+ async def get_task_result(task_id: str):
516
546
  try:
517
547
  task_mgr: TaskManager = utils_registry.get('TaskManager')
518
548
  res = await task_mgr.get_result(task_id)
@@ -528,25 +558,28 @@ def create_router(config: Configuration):
528
558
  return Response(df_to_json(res))
529
559
 
530
560
  return res
531
- except Exception as e:
532
- raise ValueError(f'The result for task id {task_id} could not be found').with_traceback(e.__traceback__)
561
+ except Exception as err:
562
+ raise ValueError(f'The result for task id {task_id} could not be found') from err
533
563
 
534
564
  @core_api_router.delete('/tasks/{task_id}', dependencies=[Depends(verify_session)])
535
- async def cancel_task(task_id: str): # pylint: disable=unused-variable
565
+ async def cancel_task(task_id: str):
536
566
  try:
537
567
  task_mgr: TaskManager = utils_registry.get('TaskManager')
538
568
  return await task_mgr.cancel_task(task_id)
539
569
  except TaskManagerError as e:
540
- dev_logger.error(f'The task id {task_id} could not be found, it may have already been cancelled', e)
570
+ dev_logger.error(
571
+ f'The task id {task_id} could not be found, it may have already been cancelled',
572
+ e,
573
+ )
541
574
 
542
575
  @core_api_router.get('/template/{template}', dependencies=[Depends(verify_session)])
543
- async def get_template(template: str): # pylint: disable=unused-variable
576
+ async def get_template(template: str):
544
577
  try:
545
578
  selected_template = template_registry.get(template)
546
579
  normalized_template, lookup = normalize(jsonable_encoder(selected_template))
547
580
  return {'data': normalized_template, 'lookup': lookup}
548
- except KeyError:
549
- raise HTTPException(status_code=404, detail=f'Template: {template}, not found in registry')
581
+ except KeyError as err:
582
+ raise HTTPException(status_code=404, detail=f'Template: {template}, not found in registry') from err
550
583
  except Exception as e:
551
584
  dev_logger.error('Something went wrong while trying to get the template', e)
552
585
 
@@ -20,7 +20,7 @@ from datetime import datetime
20
20
  from multiprocessing import get_context
21
21
  from multiprocessing.process import BaseProcess
22
22
  from pickle import PicklingError
23
- from typing import Any, List, Optional, Union
23
+ from typing import Any, List, Optional, Union, cast
24
24
 
25
25
  from croniter import croniter
26
26
  from pydantic import BaseModel, field_validator
@@ -56,25 +56,28 @@ class ScheduledJob(BaseModel):
56
56
  job_process = ctx.Process(target=self._refresh_timer, args=(func, args), daemon=True)
57
57
  job_process.start()
58
58
  return job_process
59
- except PicklingError:
59
+ except PicklingError as err:
60
60
  raise PicklingError(
61
61
  """
62
62
  Unable to pickle scheduled function. Please ensure that the function you are trying
63
63
  to schedule is not in the same file as the ConfigurationBuilder is defined and that
64
64
  the function is not a lambda.
65
65
  """
66
- )
66
+ ) from err
67
67
 
68
68
  def _refresh_timer(self, func, args):
69
69
  while self.continue_running and not (self.run_once and not self.first_execution):
70
- interval = self.interval
70
+ interval: int
71
71
  # If there's more than one interval to wait, i.e. this is a weekday process
72
- if type(interval) == list:
72
+ if isinstance(self.interval, list):
73
73
  # Wait the first interval if this is the first execution of the job
74
74
  interval = self.interval[0] if self.first_execution else self.interval[1]
75
+ else:
76
+ interval = self.interval
77
+
75
78
  self.first_execution = False
76
79
  # Wait the interval and then run the job
77
- time.sleep(interval)
80
+ time.sleep(cast(int, interval))
78
81
  func(*args)
79
82
 
80
83
 
@@ -176,7 +179,7 @@ class ScheduledJobFactory(BaseModel):
176
179
 
177
180
  @field_validator('weekday', mode='before')
178
181
  @classmethod
179
- def validate_weekday(cls, weekday: Any) -> datetime: # pylint: disable=E0213
182
+ def validate_weekday(cls, weekday: Any) -> datetime:
180
183
  if isinstance(weekday, datetime):
181
184
  return weekday
182
185
  if isinstance(weekday, str):
@@ -283,7 +286,9 @@ class Scheduler:
283
286
  def _weekday(self, weekday: int):
284
287
  # The job must execute on a weekly interval
285
288
  return ScheduledJobFactory(
286
- interval=self.interval * 604800, run_once=self._run_once, weekday=str(weekday) # type: ignore
289
+ interval=self.interval * 604800,
290
+ run_once=self._run_once,
291
+ weekday=str(weekday), # type: ignore
287
292
  )
288
293
 
289
294
  def monday(self) -> ScheduledJobFactory:
@@ -74,7 +74,7 @@ def generate_env_file(filename='.env'):
74
74
  f.write(env_content)
75
75
 
76
76
 
77
- @lru_cache()
77
+ @lru_cache
78
78
  def get_settings():
79
79
  """
80
80
  Get a cached instance of the settings, loading values from the .env if present.
@@ -83,7 +83,7 @@ def get_settings():
83
83
 
84
84
  # Test purposes - if DARA_TEST_FLAG is set then override env with .env.test
85
85
  if os.environ.get('DARA_TEST_FLAG', None) is not None:
86
- return Settings(**dotenv_values('.env.test'))
86
+ return Settings(**dotenv_values('.env.test')) # type: ignore
87
87
 
88
88
  env_error = False
89
89
 
@@ -15,9 +15,11 @@ See the License for the specific language governing permissions and
15
15
  limitations under the License.
16
16
  """
17
17
 
18
+ import contextlib
18
19
  import inspect
19
20
  import math
20
- from typing import Any, Awaitable, Callable, Dict, List, Optional, Union, overload
21
+ from collections.abc import Awaitable
22
+ from typing import Any, Callable, Dict, List, Optional, Union, overload
21
23
 
22
24
  from anyio import (
23
25
  CancelScope,
@@ -137,14 +139,12 @@ class Task(BaseTask):
137
139
 
138
140
  async def on_progress(progress: float, msg: str):
139
141
  if send_stream is not None:
140
- try:
142
+ with contextlib.suppress(ClosedResourceError):
141
143
  await send_stream.send(TaskProgressUpdate(task_id=self.task_id, progress=progress, message=msg))
142
- except ClosedResourceError:
143
- pass
144
144
 
145
145
  async def on_result(result: Any):
146
146
  if send_stream is not None:
147
- try:
147
+ with contextlib.suppress(ClosedResourceError):
148
148
  await send_stream.send(
149
149
  TaskResult(
150
150
  task_id=self.task_id,
@@ -153,19 +153,15 @@ class Task(BaseTask):
153
153
  reg_entry=self.reg_entry,
154
154
  )
155
155
  )
156
- except ClosedResourceError:
157
- pass
158
156
 
159
157
  async def on_error(exc: BaseException):
160
158
  if send_stream is not None:
161
- try:
159
+ with contextlib.suppress(ClosedResourceError):
162
160
  await send_stream.send(
163
161
  TaskError(
164
162
  task_id=self.task_id, error=exc, cache_key=self.cache_key, reg_entry=self.reg_entry
165
163
  )
166
164
  )
167
- except ClosedResourceError:
168
- pass
169
165
 
170
166
  with pool.on_progress(self.task_id, on_progress):
171
167
  pool_task_def = pool.submit(self.task_id, self._func_name, args=tuple(self._args), kwargs=self._kwargs)
@@ -362,12 +358,10 @@ class TaskManager:
362
358
  self.store = store
363
359
 
364
360
  @overload
365
- async def run_task(self, task: PendingTask, ws_channel: Optional[str] = None) -> Any:
366
- ...
361
+ async def run_task(self, task: PendingTask, ws_channel: Optional[str] = None) -> Any: ...
367
362
 
368
363
  @overload
369
- async def run_task(self, task: BaseTask, ws_channel: Optional[str] = None) -> PendingTask:
370
- ...
364
+ async def run_task(self, task: BaseTask, ws_channel: Optional[str] = None) -> PendingTask: ...
371
365
 
372
366
  async def run_task(self, task: BaseTask, ws_channel: Optional[str] = None):
373
367
  """
@@ -20,6 +20,7 @@ from __future__ import annotations
20
20
  import asyncio
21
21
  import inspect
22
22
  import os
23
+ from collections.abc import Awaitable, Coroutine, Sequence
23
24
  from functools import wraps
24
25
  from importlib import import_module
25
26
  from importlib.util import find_spec
@@ -27,13 +28,10 @@ from types import ModuleType
27
28
  from typing import (
28
29
  TYPE_CHECKING,
29
30
  Any,
30
- Awaitable,
31
31
  Callable,
32
- Coroutine,
33
32
  Dict,
34
33
  Literal,
35
34
  Optional,
36
- Sequence,
37
35
  Tuple,
38
36
  Union,
39
37
  )
@@ -79,7 +77,7 @@ def get_cache_scope(cache_type: Optional[CacheType]) -> CacheScope:
79
77
  return 'global'
80
78
 
81
79
 
82
- async def run_user_handler(handler: Callable, args: Sequence = [], kwargs: dict = {}):
80
+ async def run_user_handler(handler: Callable, args: Union[Sequence, None] = None, kwargs: Union[dict, None] = None):
83
81
  """
84
82
  Run a user-defined handler function. Runs sync functions in a threadpool.
85
83
  Handles SystemExits cleanly.
@@ -88,6 +86,10 @@ async def run_user_handler(handler: Callable, args: Sequence = [], kwargs: dict
88
86
  :param args: list of arguments to pass to the function
89
87
  :param kwargs: dict of kwargs to past to the function
90
88
  """
89
+ if args is None:
90
+ args = []
91
+ if kwargs is None:
92
+ kwargs = {}
91
93
  with handle_system_exit('User defined function quit unexpectedly'):
92
94
  if inspect.iscoroutinefunction(handler):
93
95
  return await handler(*args, **kwargs)
@@ -164,14 +166,14 @@ def enforce_sso(conf: ConfigurationBuilder):
164
166
  Raises if SSO is not used
165
167
  """
166
168
  try:
167
- from dara.enterprise import SSOAuthConfig
169
+ from dara.enterprise import SSOAuthConfig # pyright: ignore[reportMissingImports]
168
170
 
169
171
  if conf.auth_config is None or not isinstance(conf.auth_config, SSOAuthConfig):
170
172
  raise ValueError('Config does not have SSO auth enabled. Please update your application to configure SSO.')
171
- except ImportError:
173
+ except ImportError as err:
172
174
  raise ValueError(
173
175
  'SSO is not enabled. Please install the dara_enterprise package and configure SSO to use this feature.'
174
- )
176
+ ) from err
175
177
 
176
178
 
177
179
  def async_dedupe(fn: Callable[..., Awaitable]):
@@ -228,8 +230,7 @@ def resolve_exception_group(error: Any):
228
230
 
229
231
  :param error: The error to resolve
230
232
  """
231
- if isinstance(error, ExceptionGroup):
232
- if len(error.exceptions) == 1:
233
- return resolve_exception_group(error.exceptions[0])
233
+ if isinstance(error, ExceptionGroup) and len(error.exceptions) == 1:
234
+ return resolve_exception_group(error.exceptions[0])
234
235
 
235
236
  return error
@@ -204,28 +204,27 @@ class WebSocketHandler:
204
204
  message_id = message.channel
205
205
 
206
206
  # If the message has a channel ID, it's a response to a previous message
207
- if message_id:
208
- if message_id in self.pending_responses:
209
- event, existing_messages = self.pending_responses[message_id]
210
-
211
- # If the response is chunked then collect the messages in pending responses
212
- if message.chunk_count is not None:
213
- if existing_messages is not None and isinstance(existing_messages, list):
214
- existing_messages.append(message.message)
215
- else:
216
- existing_messages = [message.message]
217
- self.pending_responses[message_id] = (
218
- event,
219
- existing_messages,
220
- )
207
+ if message_id and message_id in self.pending_responses:
208
+ event, existing_messages = self.pending_responses[message_id]
221
209
 
222
- # If all chunks have been received, set the event to notify the waiting coroutine
223
- if len(existing_messages) == message.chunk_count:
224
- event.set()
210
+ # If the response is chunked then collect the messages in pending responses
211
+ if message.chunk_count is not None:
212
+ if existing_messages is not None and isinstance(existing_messages, list):
213
+ existing_messages.append(message.message)
225
214
  else:
226
- # Store the response and set the event to notify the waiting coroutine
227
- self.pending_responses[message_id] = (event, message.message)
215
+ existing_messages = [message.message]
216
+ self.pending_responses[message_id] = (
217
+ event,
218
+ existing_messages,
219
+ )
220
+
221
+ # If all chunks have been received, set the event to notify the waiting coroutine
222
+ if len(existing_messages) == message.chunk_count:
228
223
  event.set()
224
+ else:
225
+ # Store the response and set the event to notify the waiting coroutine
226
+ self.pending_responses[message_id] = (event, message.message)
227
+ event.set()
229
228
 
230
229
  return None
231
230