dara-core 1.21.0__py3-none-any.whl → 1.21.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -16,26 +16,34 @@ limitations under the License.
16
16
  """
17
17
 
18
18
  import inspect
19
+ import json
20
+ import math
19
21
  import os
22
+ import traceback
20
23
  from collections.abc import Mapping
21
24
  from functools import wraps
22
25
  from importlib.metadata import version
23
- from typing import Any, Callable, Dict, List, Optional
26
+ from typing import Annotated, Any, Callable, Dict, List, Literal, Optional, Union
27
+ from urllib.parse import unquote
24
28
 
25
29
  import anyio
30
+ from anyio.streams.memory import MemoryObjectSendStream
26
31
  from fastapi import (
27
32
  APIRouter,
28
33
  Body,
29
34
  Depends,
35
+ FastAPI,
30
36
  File,
31
37
  Form,
32
38
  HTTPException,
39
+ Path,
33
40
  Response,
34
41
  UploadFile,
35
42
  )
43
+ from fastapi.encoders import jsonable_encoder
36
44
  from fastapi.responses import StreamingResponse
37
45
  from pandas import DataFrame
38
- from pydantic import BaseModel
46
+ from pydantic import BaseModel, Field
39
47
  from starlette.background import BackgroundTask
40
48
  from starlette.status import HTTP_415_UNSUPPORTED_MEDIA_TYPE
41
49
 
@@ -46,9 +54,10 @@ from dara.core.interactivity.any_data_variable import upload
46
54
  from dara.core.interactivity.filtering import FilterQuery, Pagination
47
55
  from dara.core.interactivity.server_variable import ServerVariable
48
56
  from dara.core.internal.cache_store import CacheStore
57
+ from dara.core.internal.devtools import print_stacktrace
49
58
  from dara.core.internal.download import DownloadRegistryEntry
50
- from dara.core.internal.execute_action import CURRENT_ACTION_ID
51
- from dara.core.internal.normalization import NormalizedPayload, denormalize
59
+ from dara.core.internal.execute_action import CURRENT_ACTION_ID, execute_action_sync
60
+ from dara.core.internal.normalization import NormalizedPayload, denormalize, normalize
52
61
  from dara.core.internal.pandas_utils import data_response_to_json, df_to_json, is_data_response
53
62
  from dara.core.internal.registries import (
54
63
  action_def_registry,
@@ -103,385 +112,585 @@ def error_decorator(handler: Callable[..., Any]):
103
112
  return _inner_func
104
113
 
105
114
 
106
- def create_router(config: Configuration):
107
- """
108
- Create the main Dara core API router
115
+ core_api_router = APIRouter()
109
116
 
110
- :param config: Dara app configuration
111
- """
112
- core_api_router = APIRouter()
113
117
 
114
- @core_api_router.get('/actions', dependencies=[Depends(verify_session)])
115
- async def get_actions():
116
- return action_def_registry.get_all().items()
118
+ @core_api_router.get('/actions', dependencies=[Depends(verify_session)])
119
+ async def get_actions():
120
+ return action_def_registry.get_all().items()
117
121
 
118
- class ActionRequestBody(BaseModel):
119
- values: NormalizedPayload[Mapping[str, Any]]
120
- """Dynamic kwarg values"""
121
122
 
122
- input: Any = None
123
- """Input from the component"""
123
+ class ActionRequestBody(BaseModel):
124
+ values: NormalizedPayload[Mapping[str, Any]]
125
+ """Dynamic kwarg values"""
124
126
 
125
- ws_channel: str
126
- """Websocket channel assigned to the client"""
127
+ input: Any = None
128
+ """Input from the component"""
127
129
 
128
- uid: str
129
- """Instance uid"""
130
+ ws_channel: str
131
+ """Websocket channel assigned to the client"""
130
132
 
131
- execution_id: str
132
- """Execution id, unique to this request"""
133
+ uid: str
134
+ """Instance uid"""
133
135
 
134
- @core_api_router.post('/action/{uid}', dependencies=[Depends(verify_session)])
135
- async def get_action(uid: str, body: ActionRequestBody):
136
- store: CacheStore = utils_registry.get('Store')
137
- task_mgr: TaskManager = utils_registry.get('TaskManager')
138
- registry_mgr: RegistryLookup = utils_registry.get('RegistryLookup')
139
- action_def: ActionResolverDef = await registry_mgr.get(action_registry, uid)
136
+ execution_id: str
137
+ """Execution id, unique to this request"""
140
138
 
141
- CURRENT_ACTION_ID.set(body.uid)
142
- WS_CHANNEL.set(body.ws_channel)
143
139
 
144
- # Denormalize the values
145
- values = denormalize(body.values.data, body.values.lookup)
140
+ @core_api_router.post('/action/{uid}', dependencies=[Depends(verify_session)])
141
+ async def get_action(uid: str, body: ActionRequestBody):
142
+ store: CacheStore = utils_registry.get('Store')
143
+ task_mgr: TaskManager = utils_registry.get('TaskManager')
144
+ registry_mgr: RegistryLookup = utils_registry.get('RegistryLookup')
145
+ action_def: ActionResolverDef = await registry_mgr.get(action_registry, uid)
146
+
147
+ CURRENT_ACTION_ID.set(body.uid)
148
+ WS_CHANNEL.set(body.ws_channel)
149
+
150
+ # Denormalize the values
151
+ values = denormalize(body.values.data, body.values.lookup)
152
+
153
+ # Fetch static kwargs
154
+ static_kwargs = await registry_mgr.get(static_kwargs_registry, body.uid)
155
+
156
+ # Execute the action - kick off a background task to run the handler
157
+ response = await action_def.execute_action(
158
+ action_def,
159
+ body.input,
160
+ values,
161
+ static_kwargs,
162
+ body.execution_id,
163
+ body.ws_channel,
164
+ store,
165
+ task_mgr,
166
+ )
167
+
168
+ if isinstance(response, BaseTask):
169
+ await task_mgr.run_task(response, body.ws_channel)
170
+ return {'task_id': response.task_id}
171
+
172
+ return {'execution_id': response}
173
+
174
+
175
+ @core_api_router.get('/download') # explicitly unauthenticated
176
+ async def get_download(code: str):
177
+ store: CacheStore = utils_registry.get('Store')
146
178
 
147
- # Fetch static kwargs
179
+ try:
180
+ data_entry = await store.get(DownloadRegistryEntry, key=code)
181
+
182
+ # If not found directly in the store, use the override registry
183
+ # to check if we can get the download entry from there
184
+ if data_entry is None:
185
+ registry_mgr: RegistryLookup = utils_registry.get('RegistryLookup')
186
+ # NOTE: This will throw a Value/KeyError if the code is not found so no need to rethrow
187
+ data_entry = await registry_mgr.get(download_code_registry, code)
188
+ # We managed to find one from the lookup,
189
+ # remove it from the registry immediately because it's one time use
190
+ download_code_registry.remove(code)
191
+
192
+ async_file, cleanup = await data_entry.download(data_entry)
193
+
194
+ file_name = os.path.basename(data_entry.file_path)
195
+
196
+ # This mirrors builtin's FastAPI FileResponse implementation
197
+ async def stream_file():
198
+ has_content = True
199
+ chunk_size = 64 * 1024
200
+ while has_content:
201
+ chunk = await async_file.read(chunk_size)
202
+ has_content = chunk_size == len(chunk)
203
+ yield chunk
204
+
205
+ return StreamingResponse(
206
+ content=stream_file(),
207
+ headers={'Content-Disposition': f'attachment; filename={file_name}'},
208
+ background=BackgroundTask(cleanup),
209
+ )
210
+
211
+ except (KeyError, ValueError) as e:
212
+ raise ValueError('Invalid or expired download code') from e
213
+
214
+
215
+ @core_api_router.get('/components/{name}/definition', dependencies=[Depends(verify_session)])
216
+ async def get_component_definition(name: str):
217
+ """
218
+ Attempt to refetch a component definition from the backend.
219
+ This is used when a component isn't immediately available in the initial registry,
220
+ e.g. when it was added by a py_component.
221
+
222
+ :param name: the name of component
223
+ """
224
+ registry_mgr: RegistryLookup = utils_registry.get('RegistryLookup')
225
+ component = await registry_mgr.get(component_registry, name)
226
+ return component.model_dump(exclude={'func'})
227
+
228
+
229
+ class ComponentRequestBody(BaseModel):
230
+ # Dynamic kwarg values
231
+ values: NormalizedPayload[Mapping[str, Any]]
232
+ # Instance uid
233
+ uid: str
234
+ # Websocket channel assigned to the client
235
+ ws_channel: str
236
+
237
+
238
+ @core_api_router.post('/components/{component}', dependencies=[Depends(verify_session)])
239
+ async def get_component(component: str, body: ComponentRequestBody):
240
+ CURRENT_COMPONENT_ID.set(body.uid)
241
+ WS_CHANNEL.set(body.ws_channel)
242
+ store: CacheStore = utils_registry.get('Store')
243
+ task_mgr: TaskManager = utils_registry.get('TaskManager')
244
+ registry_mgr: RegistryLookup = utils_registry.get('RegistryLookup')
245
+ comp_def = await registry_mgr.get(component_registry, component)
246
+
247
+ if isinstance(comp_def, PyComponentDef):
148
248
  static_kwargs = await registry_mgr.get(static_kwargs_registry, body.uid)
249
+ values = denormalize(body.values.data, body.values.lookup)
250
+
251
+ response = await comp_def.render_component(comp_def, store, task_mgr, values, static_kwargs)
149
252
 
150
- # Execute the action - kick off a background task to run the handler
151
- response = await action_def.execute_action(
152
- action_def,
153
- body.input,
154
- values,
155
- static_kwargs,
156
- body.execution_id,
157
- body.ws_channel,
158
- store,
159
- task_mgr,
253
+ dev_logger.debug(
254
+ f'PyComponent {comp_def.func.__name__ if comp_def.func else "anonymous"}',
255
+ 'return value',
256
+ {'value': response},
160
257
  )
161
258
 
162
259
  if isinstance(response, BaseTask):
163
260
  await task_mgr.run_task(response, body.ws_channel)
164
261
  return {'task_id': response.task_id}
165
262
 
166
- return {'execution_id': response}
263
+ return response
264
+
265
+ raise HTTPException(status_code=400, detail='Requesting this type of component is not supported')
266
+
167
267
 
168
- @core_api_router.get('/download') # explicitly unauthenticated
169
- async def get_download(code: str):
268
+ @core_api_router.get('/derived-variable/{uid}/latest', dependencies=[Depends(verify_session)])
269
+ async def get_latest_derived_variable(uid: str):
270
+ try:
170
271
  store: CacheStore = utils_registry.get('Store')
272
+ latest_value_entry = latest_value_registry.get(uid)
273
+ variable_entry = derived_variable_registry.get(uid)
171
274
 
172
- try:
173
- data_entry = await store.get(DownloadRegistryEntry, key=code)
174
-
175
- # If not found directly in the store, use the override registry
176
- # to check if we can get the download entry from there
177
- if data_entry is None:
178
- registry_mgr: RegistryLookup = utils_registry.get('RegistryLookup')
179
- # NOTE: This will throw a Value/KeyError if the code is not found so no need to rethrow
180
- data_entry = await registry_mgr.get(download_code_registry, code)
181
- # We managed to find one from the lookup,
182
- # remove it from the registry immediately because it's one time use
183
- download_code_registry.remove(code)
184
-
185
- async_file, cleanup = await data_entry.download(data_entry)
186
-
187
- file_name = os.path.basename(data_entry.file_path)
188
-
189
- # This mirrors builtin's FastAPI FileResponse implementation
190
- async def stream_file():
191
- has_content = True
192
- chunk_size = 64 * 1024
193
- while has_content:
194
- chunk = await async_file.read(chunk_size)
195
- has_content = chunk_size == len(chunk)
196
- yield chunk
197
-
198
- return StreamingResponse(
199
- content=stream_file(),
200
- headers={'Content-Disposition': f'attachment; filename={file_name}'},
201
- background=BackgroundTask(cleanup),
202
- )
203
-
204
- except (KeyError, ValueError) as e:
205
- raise ValueError('Invalid or expired download code') from e
206
-
207
- @core_api_router.get('/components/{name}/definition', dependencies=[Depends(verify_session)])
208
- async def get_component_definition(name: str):
209
- """
210
- Attempt to refetch a component definition from the backend.
211
- This is used when a component isn't immediately available in the initial registry,
212
- e.g. when it was added by a py_component.
213
-
214
- :param name: the name of component
215
- """
275
+ # Lookup the latest key in the cache
276
+ scope = get_cache_scope(variable_entry.cache.cache_type if variable_entry.cache else None)
277
+ latest_key = await store.get(latest_value_entry, key=scope)
278
+
279
+ if latest_key is None:
280
+ return None
281
+
282
+ # Lookup latest value for that key
283
+ latest_value = await store.get_or_wait(variable_entry, key=latest_key)
284
+
285
+ dev_logger.debug(
286
+ f'DerivedVariable {variable_entry.uid[:3]}..{variable_entry.uid[-3:]}',
287
+ 'latest value',
288
+ {'value': latest_value, 'uid': uid},
289
+ )
290
+ return latest_value
291
+
292
+ except KeyError as err:
293
+ raise ValueError(f'Could not find latest value for derived variable with uid: {uid}') from err
294
+
295
+
296
+ class TabularRequestBody(BaseModel):
297
+ filters: Optional[FilterQuery] = None
298
+ ws_channel: str
299
+ dv_values: Optional[NormalizedPayload[List[Any]]] = None
300
+ """DerivedVariable values if variable is a DerivedVariable"""
301
+ force_key: Optional[str] = None
302
+ """Optional force key if variable is a DerivedVariable and a recalculation is forced"""
303
+
304
+
305
+ @core_api_router.post('/tabular-variable/{uid}', dependencies=[Depends(verify_session)])
306
+ async def get_tabular_variable(
307
+ uid: str,
308
+ body: TabularRequestBody,
309
+ offset: Optional[int] = None,
310
+ limit: Optional[int] = None,
311
+ order_by: Optional[str] = None,
312
+ index: Optional[str] = None,
313
+ ):
314
+ """
315
+ Generic endpoint for getting tabular data from a variable.
316
+ Supports ServerVariables and DerivedVariables.
317
+ """
318
+ WS_CHANNEL.set(body.ws_channel)
319
+
320
+ try:
321
+ pagination = Pagination(offset=offset, limit=limit, orderBy=order_by, index=index)
216
322
  registry_mgr: RegistryLookup = utils_registry.get('RegistryLookup')
217
- component = await registry_mgr.get(component_registry, name)
218
- return component.model_dump(exclude={'func'})
219
-
220
- class ComponentRequestBody(BaseModel):
221
- # Dynamic kwarg values
222
- values: NormalizedPayload[Mapping[str, Any]]
223
- # Instance uid
224
- uid: str
225
- # Websocket channel assigned to the client
226
- ws_channel: str
227
-
228
- @core_api_router.post('/components/{component}', dependencies=[Depends(verify_session)])
229
- async def get_component(component: str, body: ComponentRequestBody):
230
- CURRENT_COMPONENT_ID.set(body.uid)
231
- WS_CHANNEL.set(body.ws_channel)
323
+
324
+ # ServerVariable
325
+ if body.dv_values is None:
326
+ server_variable_entry = await registry_mgr.get(server_variable_registry, uid)
327
+ data_response = await ServerVariable.get_tabular_data(server_variable_entry, body.filters, pagination)
328
+ return Response(data_response_to_json(data_response), media_type='application/json')
329
+
330
+ # DerivedVariable
232
331
  store: CacheStore = utils_registry.get('Store')
233
332
  task_mgr: TaskManager = utils_registry.get('TaskManager')
234
- registry_mgr: RegistryLookup = utils_registry.get('RegistryLookup')
235
- comp_def = await registry_mgr.get(component_registry, component)
333
+ variable_def = await registry_mgr.get(derived_variable_registry, uid)
334
+ values = denormalize(body.dv_values.data, body.dv_values.lookup)
236
335
 
237
- if isinstance(comp_def, PyComponentDef):
238
- static_kwargs = await registry_mgr.get(static_kwargs_registry, body.uid)
239
- values = denormalize(body.values.data, body.values.lookup)
336
+ result = await variable_def.get_tabular_data(
337
+ variable_def, store, task_mgr, values, body.force_key, pagination, body.filters
338
+ )
240
339
 
241
- response = await comp_def.render_component(comp_def, store, task_mgr, values, static_kwargs)
340
+ if isinstance(result, BaseTask):
341
+ await task_mgr.run_task(result, body.ws_channel)
342
+ return {'task_id': result.task_id}
242
343
 
243
- dev_logger.debug(
244
- f'PyComponent {comp_def.func.__name__ if comp_def.func else "anonymous"}',
245
- 'return value',
246
- {'value': response},
247
- )
344
+ return Response(data_response_to_json(result), media_type='application/json')
345
+ except NonTabularDataError as e:
346
+ raise HTTPException(status_code=HTTP_415_UNSUPPORTED_MEDIA_TYPE, detail=str(e)) from e
248
347
 
249
- if isinstance(response, BaseTask):
250
- await task_mgr.run_task(response, body.ws_channel)
251
- return {'task_id': response.task_id}
252
348
 
253
- return response
349
+ @core_api_router.get('/server-variable/{uid}/sequence', dependencies=[Depends(verify_session)])
350
+ async def get_server_variable_sequence(
351
+ uid: str,
352
+ ):
353
+ registry_mgr: RegistryLookup = utils_registry.get('RegistryLookup')
354
+ server_variable_entry = await registry_mgr.get(server_variable_registry, uid)
355
+ seq_num = await ServerVariable.get_sequence_number(server_variable_entry)
356
+ return {'sequence_number': seq_num}
254
357
 
255
- raise HTTPException(status_code=400, detail='Requesting this type of component is not supported')
256
358
 
257
- @core_api_router.get('/derived-variable/{uid}/latest', dependencies=[Depends(verify_session)])
258
- async def get_latest_derived_variable(uid: str):
259
- try:
260
- store: CacheStore = utils_registry.get('Store')
261
- latest_value_entry = latest_value_registry.get(uid)
262
- variable_entry = derived_variable_registry.get(uid)
263
-
264
- # Lookup the latest key in the cache
265
- scope = get_cache_scope(variable_entry.cache.cache_type if variable_entry.cache else None)
266
- latest_key = await store.get(latest_value_entry, key=scope)
267
-
268
- if latest_key is None:
269
- return None
270
-
271
- # Lookup latest value for that key
272
- latest_value = await store.get_or_wait(variable_entry, key=latest_key)
273
-
274
- dev_logger.debug(
275
- f'DerivedVariable {variable_entry.uid[:3]}..{variable_entry.uid[-3:]}',
276
- 'latest value',
277
- {'value': latest_value, 'uid': uid},
278
- )
279
- return latest_value
280
-
281
- except KeyError as err:
282
- raise ValueError(f'Could not find latest value for derived variable with uid: {uid}') from err
283
-
284
- class TabularRequestBody(BaseModel):
285
- filters: Optional[FilterQuery] = None
286
- ws_channel: str
287
- dv_values: Optional[NormalizedPayload[List[Any]]] = None
288
- """DerivedVariable values if variable is a DerivedVariable"""
289
- force_key: Optional[str] = None
290
- """Optional force key if variable is a DerivedVariable and a recalculation is forced"""
291
-
292
- @core_api_router.post('/tabular-variable/{uid}', dependencies=[Depends(verify_session)])
293
- async def get_tabular_variable(
294
- uid: str,
295
- body: TabularRequestBody,
296
- offset: Optional[int] = None,
297
- limit: Optional[int] = None,
298
- order_by: Optional[str] = None,
299
- index: Optional[str] = None,
300
- ):
301
- """
302
- Generic endpoint for getting tabular data from a variable.
303
- Supports ServerVariables and DerivedVariables.
304
- """
305
- WS_CHANNEL.set(body.ws_channel)
359
+ @core_api_router.post('/data/upload', dependencies=[Depends(verify_session)])
360
+ async def upload_data(
361
+ data_uid: Optional[str] = None,
362
+ data: UploadFile = File(),
363
+ resolver_id: Optional[str] = Form(default=None),
364
+ ):
365
+ """
366
+ Upload endpoint.
367
+ Can run a custom resolver_id (if previously registered, otherwise runs a default one)
368
+ and update a data variable with its return value (if target is specified).
369
+ """
370
+ registry_mgr: RegistryLookup = utils_registry.get('RegistryLookup')
306
371
 
307
- try:
308
- pagination = Pagination(offset=offset, limit=limit, orderBy=order_by, index=index)
309
- registry_mgr: RegistryLookup = utils_registry.get('RegistryLookup')
372
+ if data_uid is None and resolver_id is None:
373
+ raise HTTPException(
374
+ 400,
375
+ 'Neither resolver_id or data_uid specified, at least one of them is required',
376
+ )
310
377
 
311
- # ServerVariable
312
- if body.dv_values is None:
313
- server_variable_entry = await registry_mgr.get(server_variable_registry, uid)
314
- data_response = await ServerVariable.get_tabular_data(server_variable_entry, body.filters, pagination)
315
- return Response(data_response_to_json(data_response), media_type='application/json')
378
+ try:
379
+ # If resolver id is provided, run the custom
380
+ if resolver_id:
381
+ upload_resolver_def: UploadResolverDef = await registry_mgr.get(upload_resolver_registry, resolver_id)
382
+ await upload_resolver_def.upload(data, data_uid, resolver_id)
383
+ else:
384
+ # Run the default logic as a fallback, e.g. programmatic upload
385
+ await upload(data, data_uid, resolver_id)
316
386
 
317
- # DerivedVariable
318
- store: CacheStore = utils_registry.get('Store')
319
- task_mgr: TaskManager = utils_registry.get('TaskManager')
320
- variable_def = await registry_mgr.get(derived_variable_registry, uid)
321
- values = denormalize(body.dv_values.data, body.dv_values.lookup)
387
+ return {'status': 'SUCCESS'}
388
+ except Exception as e:
389
+ raise HTTPException(status_code=400, detail=str(e)) from e
322
390
 
323
- result = await variable_def.get_tabular_data(
324
- variable_def, store, task_mgr, values, body.force_key, pagination, body.filters
325
- )
326
391
 
327
- if isinstance(result, BaseTask):
328
- await task_mgr.run_task(result, body.ws_channel)
329
- return {'task_id': result.task_id}
392
+ class DerivedStateRequestBody(BaseModel):
393
+ values: NormalizedPayload[List[Any]]
394
+ force_key: Optional[str] = None
395
+ ws_channel: str
330
396
 
331
- return Response(data_response_to_json(result), media_type='application/json')
332
- except NonTabularDataError as e:
333
- raise HTTPException(status_code=HTTP_415_UNSUPPORTED_MEDIA_TYPE, detail=str(e)) from e
334
397
 
335
- @core_api_router.get('/server-variable/{uid}/sequence', dependencies=[Depends(verify_session)])
336
- async def get_server_variable_sequence(
337
- uid: str,
338
- ):
339
- registry_mgr: RegistryLookup = utils_registry.get('RegistryLookup')
340
- server_variable_entry = await registry_mgr.get(server_variable_registry, uid)
341
- seq_num = await ServerVariable.get_sequence_number(server_variable_entry)
342
- return {'sequence_number': seq_num}
343
-
344
- @core_api_router.post('/data/upload', dependencies=[Depends(verify_session)])
345
- async def upload_data(
346
- data_uid: Optional[str] = None,
347
- data: UploadFile = File(),
348
- resolver_id: Optional[str] = Form(default=None),
349
- ):
350
- """
351
- Upload endpoint.
352
- Can run a custom resolver_id (if previously registered, otherwise runs a default one)
353
- and update a data variable with its return value (if target is specified).
354
- """
355
- registry_mgr: RegistryLookup = utils_registry.get('RegistryLookup')
398
+ @core_api_router.post('/derived-variable/{uid}', dependencies=[Depends(verify_session)])
399
+ async def get_derived_variable(uid: str, body: DerivedStateRequestBody):
400
+ task_mgr: TaskManager = utils_registry.get('TaskManager')
401
+ store: CacheStore = utils_registry.get('Store')
402
+ registry_mgr: RegistryLookup = utils_registry.get('RegistryLookup')
403
+ variable_def = await registry_mgr.get(derived_variable_registry, uid)
356
404
 
357
- if data_uid is None and resolver_id is None:
358
- raise HTTPException(
359
- 400,
360
- 'Neither resolver_id or data_uid specified, at least one of them is required',
361
- )
405
+ values = denormalize(body.values.data, body.values.lookup)
362
406
 
363
- try:
364
- # If resolver id is provided, run the custom
365
- if resolver_id:
366
- upload_resolver_def: UploadResolverDef = await registry_mgr.get(upload_resolver_registry, resolver_id)
367
- await upload_resolver_def.upload(data, data_uid, resolver_id)
368
- else:
369
- # Run the default logic as a fallback, e.g. programmatic upload
370
- await upload(data, data_uid, resolver_id)
371
-
372
- return {'status': 'SUCCESS'}
373
- except Exception as e:
374
- raise HTTPException(status_code=400, detail=str(e)) from e
375
-
376
- class DerivedStateRequestBody(BaseModel):
377
- values: NormalizedPayload[List[Any]]
378
- force_key: Optional[str] = None
379
- ws_channel: str
380
-
381
- @core_api_router.post('/derived-variable/{uid}', dependencies=[Depends(verify_session)])
382
- async def get_derived_variable(uid: str, body: DerivedStateRequestBody):
383
- task_mgr: TaskManager = utils_registry.get('TaskManager')
384
- store: CacheStore = utils_registry.get('Store')
385
- registry_mgr: RegistryLookup = utils_registry.get('RegistryLookup')
386
- variable_def = await registry_mgr.get(derived_variable_registry, uid)
407
+ result = await variable_def.get_value(variable_def, store, task_mgr, values, body.force_key)
387
408
 
388
- values = denormalize(body.values.data, body.values.lookup)
409
+ response: Any = result
389
410
 
390
- result = await variable_def.get_value(variable_def, store, task_mgr, values, body.force_key)
411
+ WS_CHANNEL.set(body.ws_channel)
391
412
 
392
- response: Any = result
413
+ if isinstance(result['value'], BaseTask):
414
+ # Kick off the task
415
+ await task_mgr.run_task(result['value'], body.ws_channel)
416
+ response = {
417
+ 'task_id': result['value'].task_id,
418
+ 'cache_key': result['cache_key'],
419
+ }
393
420
 
394
- WS_CHANNEL.set(body.ws_channel)
421
+ dev_logger.debug(
422
+ f'DerivedVariable {variable_def.uid[:3]}..{variable_def.uid[-3:]}',
423
+ 'return value',
424
+ {'value': response, 'uid': uid},
425
+ )
395
426
 
396
- if isinstance(result['value'], BaseTask):
397
- # Kick off the task
398
- await task_mgr.run_task(result['value'], body.ws_channel)
399
- response = {
400
- 'task_id': result['value'].task_id,
401
- 'cache_key': result['cache_key'],
402
- }
427
+ # Return {cache_key: <cache_key>, value: <value>}
428
+ return response
403
429
 
404
- dev_logger.debug(
405
- f'DerivedVariable {variable_def.uid[:3]}..{variable_def.uid[-3:]}',
406
- 'return value',
407
- {'value': response, 'uid': uid},
408
- )
409
430
 
410
- # Return {cache_key: <cache_key>, value: <value>}
411
- return response
431
+ @core_api_router.get('/store/{store_uid}', dependencies=[Depends(verify_session)])
432
+ async def read_backend_store(store_uid: str):
433
+ registry_mgr: RegistryLookup = utils_registry.get('RegistryLookup')
434
+ store_entry: BackendStoreEntry = await registry_mgr.get(backend_store_registry, store_uid)
435
+ result = store_entry.store.read()
412
436
 
413
- @core_api_router.get('/store/{store_uid}', dependencies=[Depends(verify_session)])
414
- async def read_backend_store(store_uid: str):
415
- registry_mgr: RegistryLookup = utils_registry.get('RegistryLookup')
437
+ # Backend implementation could return a coroutine
438
+ if inspect.iscoroutine(result):
439
+ result = await result
440
+
441
+ # Get the current key and sequence number for this store
442
+ store = store_entry.store
443
+ key = await store._get_key()
444
+ sequence_number = store.sequence_number.get(key, 0)
445
+
446
+ return {'value': result, 'sequence_number': sequence_number}
447
+
448
+
449
+ @core_api_router.post('/store', dependencies=[Depends(verify_session)])
450
+ async def sync_backend_store(ws_channel: str = Body(), values: Dict[str, Any] = Body()):
451
+ registry_mgr: RegistryLookup = utils_registry.get('RegistryLookup')
452
+
453
+ async def _write(store_uid: str, value: Any):
454
+ WS_CHANNEL.set(ws_channel)
416
455
  store_entry: BackendStoreEntry = await registry_mgr.get(backend_store_registry, store_uid)
417
- result = store_entry.store.read()
456
+ result = store_entry.store.write(value, ignore_channel=ws_channel)
418
457
 
419
458
  # Backend implementation could return a coroutine
420
459
  if inspect.iscoroutine(result):
421
- result = await result
460
+ await result
422
461
 
423
- # Get the current key and sequence number for this store
424
- store = store_entry.store
425
- key = await store._get_key()
426
- sequence_number = store.sequence_number.get(key, 0)
462
+ async with anyio.create_task_group() as tg:
463
+ for store_uid, value in values.items():
464
+ tg.start_soon(_write, store_uid, value)
427
465
 
428
- return {'value': result, 'sequence_number': sequence_number}
429
466
 
430
- @core_api_router.post('/store', dependencies=[Depends(verify_session)])
431
- async def sync_backend_store(ws_channel: str = Body(), values: Dict[str, Any] = Body()):
432
- registry_mgr: RegistryLookup = utils_registry.get('RegistryLookup')
467
+ @core_api_router.get('/tasks/{task_id}', dependencies=[Depends(verify_session)])
468
+ async def get_task_result(task_id: str):
469
+ try:
470
+ task_mgr: TaskManager = utils_registry.get('TaskManager')
471
+ res = await task_mgr.get_result(task_id)
433
472
 
434
- async def _write(store_uid: str, value: Any):
435
- WS_CHANNEL.set(ws_channel)
436
- store_entry: BackendStoreEntry = await registry_mgr.get(backend_store_registry, store_uid)
437
- result = store_entry.store.write(value, ignore_channel=ws_channel)
473
+ dev_logger.debug(
474
+ f'Retrieving result for Task {task_id}',
475
+ 'return value',
476
+ {'value': res},
477
+ )
438
478
 
439
- # Backend implementation could return a coroutine
440
- if inspect.iscoroutine(result):
441
- await result
479
+ # Serialize dataframes correctly, either direct or as a DataResponse
480
+ if isinstance(res, DataFrame):
481
+ return Response(df_to_json(res), media_type='application/json')
482
+ elif is_data_response(res):
483
+ return Response(data_response_to_json(res), media_type='application/json')
442
484
 
443
- async with anyio.create_task_group() as tg:
444
- for store_uid, value in values.items():
445
- tg.start_soon(_write, store_uid, value)
485
+ return res
486
+ except KeyError as err:
487
+ raise HTTPException(status_code=404, detail=str(err)) from err
488
+ except Exception as err:
489
+ raise ValueError(f'The result for task id {task_id} could not be found') from err
446
490
 
447
- @core_api_router.get('/tasks/{task_id}', dependencies=[Depends(verify_session)])
448
- async def get_task_result(task_id: str):
449
- try:
450
- task_mgr: TaskManager = utils_registry.get('TaskManager')
451
- res = await task_mgr.get_result(task_id)
452
491
 
453
- dev_logger.debug(
454
- f'Retrieving result for Task {task_id}',
455
- 'return value',
456
- {'value': res},
457
- )
492
+ @core_api_router.delete('/tasks/{task_id}', dependencies=[Depends(verify_session)])
493
+ async def cancel_task(task_id: str):
494
+ try:
495
+ task_mgr: TaskManager = utils_registry.get('TaskManager')
496
+ return await task_mgr.cancel_task(task_id)
497
+ except TaskManagerError as e:
498
+ dev_logger.error(
499
+ f'The task id {task_id} could not be found, it may have already been cancelled',
500
+ e,
501
+ )
458
502
 
459
- # Serialize dataframes correctly, either direct or as a DataResponse
460
- if isinstance(res, DataFrame):
461
- return Response(df_to_json(res), media_type='application/json')
462
- elif is_data_response(res):
463
- return Response(data_response_to_json(res), media_type='application/json')
464
503
 
465
- return res
466
- except Exception as err:
467
- raise ValueError(f'The result for task id {task_id} could not be found') from err
504
+ @core_api_router.get('/version', dependencies=[Depends(verify_session)])
505
+ async def get_version():
506
+ return {'version': version('dara_core')}
468
507
 
469
- @core_api_router.delete('/tasks/{task_id}', dependencies=[Depends(verify_session)])
470
- async def cancel_task(task_id: str):
471
- try:
472
- task_mgr: TaskManager = utils_registry.get('TaskManager')
473
- return await task_mgr.cancel_task(task_id)
474
- except TaskManagerError as e:
475
- dev_logger.error(
476
- f'The task id {task_id} could not be found, it may have already been cancelled',
477
- e,
478
- )
479
508
 
480
- @core_api_router.get('/version', dependencies=[Depends(verify_session)])
481
- async def get_version():
482
- return {'version': version('dara_core')}
509
# Register the primary websocket endpoint on the core router
core_api_router.add_api_websocket_route(path='/ws', endpoint=ws_handler)
513
class ActionPayload(BaseModel):
    """One action to execute before the route template is streamed back."""

    # Instance uid: key for the static_kwargs_registry lookup, the value set on CURRENT_ACTION_ID,
    # and the key of this action's entry in the streamed actions map
    uid: str
    # uid of the action definition looked up in action_registry
    definition_uid: str
    # Normalized argument values, restored with denormalize(values.data, values.lookup) before execution
    values: NormalizedPayload[Mapping[str, Any]]


class DerivedVariablePayload(BaseModel):
    """A derived variable to preload and stream back as a `derived_variable` chunk."""

    # uid of the derived variable, forwarded to get_derived_variable
    uid: str
    # Normalized values, forwarded as DerivedStateRequestBody.values
    values: NormalizedPayload[List[Any]]


class PyComponentPayload(BaseModel):
    """A py_component to preload and stream back as a `py_component` chunk."""

    # Instance uid, forwarded as ComponentRequestBody.uid and echoed in the streamed chunk
    uid: str
    # Component name, forwarded as get_component's `component` argument
    name: str
    # Normalized values, forwarded as ComponentRequestBody.values
    values: NormalizedPayload[Mapping[str, Any]]


class RouteDataRequestBody(BaseModel):
    """Request body of the route loader endpoint (POST /api/core/route/{route_id})."""

    # Actions to run, executed sequentially in list order
    action_payloads: List[ActionPayload] = Field(default_factory=list)
    # Derived variables to preload; each result is streamed back as it completes
    derived_variable_payloads: List[DerivedVariablePayload] = Field(default_factory=list)
    # py_components to preload; each result is streamed back as it completes
    py_component_payloads: List[PyComponentPayload] = Field(default_factory=list)
    # Websocket channel of the requesting client; set on WS_CHANNEL when actions run and
    # forwarded to the derived variable / component requests
    ws_channel: str
    # Route parameters, passed to each action as inp['params']
    params: Dict[str, str] = Field(default_factory=dict)
536
+
537
+
538
class Result(BaseModel):
    """
    Outcome of one preloaded item.

    ``ok`` tells success from failure; ``value`` holds the computed value on success
    and the error message on failure.
    """

    ok: bool
    value: Any

    @staticmethod
    def success(value: Any) -> 'Result':
        """Wrap a successfully computed ``value``."""
        return Result(value=value, ok=True)

    @staticmethod
    def error(error: str) -> 'Result':
        """Wrap the failure message ``error``."""
        return Result(value=error, ok=False)
549
+
550
+
551
class DerivedVariableChunk(BaseModel):
    """Streamed chunk carrying the outcome of one preloaded derived variable."""

    # Discriminator identifying the chunk kind in the streamed output
    type: Literal['derived_variable'] = 'derived_variable'
    # uid of the derived variable this result belongs to
    uid: str
    # Success value or error message
    result: Result


class PyComponentChunk(BaseModel):
    """Streamed chunk carrying the outcome of one preloaded py_component."""

    # Discriminator identifying the chunk kind in the streamed output
    type: Literal['py_component'] = 'py_component'
    # uid of the py_component instance this result belongs to
    uid: str
    # Success value or error message
    result: Result


# Any chunk that may be streamed after the template and actions chunks
Chunk = Union[DerivedVariableChunk, PyComponentChunk]
564
+
565
+
566
def create_loader_route(config: Configuration, app: FastAPI):
    """
    Register the route loader endpoint, ``POST /api/core/route/{route_id}``, on ``app``.

    The endpoint does the following:

    1. Looks up ``route_id`` (url-unquoted) in the route map built from ``config.router``.
    2. Responds 404 if the route is unknown.
    3. Runs the requested actions one after another.
    4. Streams an ``application/x-ndjson`` response.

    The response lines are, in order:

    1. A ``template`` chunk with the normalized route content and its lookup table.
    2. An ``actions`` chunk mapping each action uid to its result.
    3. One ``derived_variable`` / ``py_component`` chunk per requested preload, emitted as
       each one finishes. Derived variables and py_components are processed concurrently
       with each other; items within each kind are processed sequentially.

    :param config: the app configuration; only ``config.router`` is read
    :param app: the FastAPI application to register the endpoint on
    """
    route_map = config.router.to_route_map()

    def create_chunk(x: Any) -> str:
        """Serialize ``x`` as one CRLF-terminated ndjson line."""
        return json.dumps(x) + '\r\n'

    @app.post('/api/core/route/{route_id}', dependencies=[Depends(verify_session)])
    async def get_route_data(route_id: Annotated[str, Path()], body: Annotated[RouteDataRequestBody, Body()]):
        # unquote route_id since it can be url-encoded
        route_id = unquote(route_id)

        route_data = route_map.get(route_id)
        if route_data is None:
            raise HTTPException(status_code=404, detail=f'Route {route_id} not found')

        action_results: Dict[str, Any] = {}

        if body.action_payloads:
            store: CacheStore = utils_registry.get('Store')
            task_mgr: TaskManager = utils_registry.get('TaskManager')
            registry_mgr: RegistryLookup = utils_registry.get('RegistryLookup')

            WS_CHANNEL.set(body.ws_channel)

            # Run actions in order to guarantee execution order
            for action_payload in body.action_payloads:
                action_def = await registry_mgr.get(action_registry, action_payload.definition_uid)
                static_kwargs = await registry_mgr.get(static_kwargs_registry, action_payload.uid)

                CURRENT_ACTION_ID.set(action_payload.uid)
                values = denormalize(action_payload.values.data, action_payload.values.lookup)
                try:
                    action_results[action_payload.uid] = await execute_action_sync(
                        action_def,
                        inp={'params': body.params, 'route': route_data.definition},
                        values=values,
                        static_kwargs=static_kwargs,
                        store=store,
                        task_mgr=task_mgr,
                    )
                except Exception as e:
                    # Catch Exception, not BaseException. Request cancellation, KeyboardInterrupt
                    # and SystemExit must keep propagating instead of becoming a 500 response.
                    # No assert here: an AssertionError would replace the real action error.
                    route_path = route_data.definition.full_path if route_data.definition is not None else None
                    raise HTTPException(
                        status_code=500,
                        detail={
                            'error': str(e),
                            'stacktrace': print_stacktrace(e),
                            'path': route_path,
                            'action_name': str(action_def.resolver),
                        },
                    ) from e

        async def process_variables(send_stream: MemoryObjectSendStream[Chunk]):
            """Resolve each requested derived variable in turn and push its chunk to ``send_stream``."""
            for payload in body.derived_variable_payloads:
                try:
                    # Run the usual DV endpoint logic
                    result = await get_derived_variable(
                        uid=payload.uid,
                        body=DerivedStateRequestBody(
                            values=payload.values,
                            ws_channel=body.ws_channel,
                            force_key=None,
                        ),
                    )
                    await send_stream.send(DerivedVariableChunk(uid=payload.uid, result=Result.success(result)))
                except Exception as e:
                    # A failing preload becomes an error chunk instead of aborting the stream.
                    # Cancellation (a BaseException) propagates so the task group can shut down.
                    dev_logger.error(f'Error streaming derived_variable {payload.uid}', error=e)
                    await send_stream.send(
                        DerivedVariableChunk(uid=payload.uid, result=Result.error(str(e))),
                    )

        async def process_py_components(send_stream: MemoryObjectSendStream[Chunk]):
            """Render each requested py_component in turn and push its chunk to ``send_stream``."""
            for payload in body.py_component_payloads:
                try:
                    result = await get_component(
                        component=payload.name,
                        body=ComponentRequestBody(
                            uid=payload.uid,
                            values=payload.values,
                            ws_channel=body.ws_channel,
                        ),
                    )
                    await send_stream.send(PyComponentChunk(uid=payload.uid, result=Result.success(result)))
                except Exception as e:
                    # Same policy as process_variables: report failures, let cancellation propagate
                    dev_logger.error(f'Error streaming py_component {payload.name}', error=e)
                    await send_stream.send(
                        PyComponentChunk(uid=payload.uid, result=Result.error(str(e))),
                    )

        normalized_template, lookup = normalize(jsonable_encoder(route_data.content))

        # Setup the stream response
        async def stream():
            try:
                # 1. Send the template and actions
                yield create_chunk(
                    {
                        'type': 'template',
                        'template': {
                            'data': normalized_template,
                            'lookup': lookup,
                        },
                    }
                )
                yield create_chunk({'type': 'actions', 'actions': jsonable_encoder(action_results)})

                # 2. Optionally, if there are DVs or py_components to preload, process them in the
                # background and stream them back as they arrive
                if body.derived_variable_payloads or body.py_component_payloads:
                    send_stream, receive_stream = anyio.create_memory_object_stream[Chunk](max_buffer_size=math.inf)

                    async def process_derived_state():
                        # Closing send_stream once both producers finish ends the `async for` below
                        async with send_stream, anyio.create_task_group() as tg:
                            if body.derived_variable_payloads:
                                tg.start_soon(process_variables, send_stream)
                            if body.py_component_payloads:
                                tg.start_soon(process_py_components, send_stream)

                    # receive_stream is entered before the task group. It is therefore closed
                    # only after the producers have finished or been cancelled, and it is always
                    # closed rather than leaked.
                    async with receive_stream, anyio.create_task_group() as tg:
                        tg.start_soon(process_derived_state)

                        async for item in receive_stream:
                            yield create_chunk(jsonable_encoder(item))
            except Exception as e:
                # NOTE(review): the streaming response has presumably already started, so the error
                # is only logged and the stream simply ends; the client sees a truncated stream.
                traceback.print_exc()
                dev_logger.error(f'Error streaming loader data for route {route_id}', error=e)

        return StreamingResponse(content=stream(), media_type='application/x-ndjson')