dara-core 1.16.20a1__py3-none-any.whl → 1.16.22__py3-none-any.whl

This diff shows the changes between two publicly available versions of this package that have been released to one of the supported registries. The information it contains is provided for informational purposes only and reflects those changes as the package versions appear in their respective public registries.
dara/core/__init__.py CHANGED
@@ -34,6 +34,7 @@ __version__ = version('dara-core')
34
34
 
35
35
  __all__ = [
36
36
  'action',
37
+ 'ActionCtx',
37
38
  'ConfigurationBuilder',
38
39
  'DerivedVariable',
39
40
  'DerivedDataVariable',
@@ -20,6 +20,7 @@ from __future__ import annotations
20
20
  from pydantic import BaseModel
21
21
 
22
22
  from dara.core.interactivity.actions import (
23
+ ActionCtx,
23
24
  DownloadContent,
24
25
  DownloadContentImpl,
25
26
  DownloadVariable,
@@ -45,6 +46,7 @@ from dara.core.interactivity.url_variable import UrlVariable
45
46
 
46
47
  __all__ = [
47
48
  'action',
49
+ 'ActionCtx',
48
50
  'AnyVariable',
49
51
  'AnyDataVariable',
50
52
  'DataVariable',
@@ -52,6 +52,7 @@ from dara.core.base_definitions import (
52
52
  AnnotatedAction,
53
53
  )
54
54
  from dara.core.base_definitions import DaraBaseModel as BaseModel
55
+ from dara.core.base_definitions import TaskProgressUpdate
55
56
  from dara.core.interactivity.data_variable import DataVariable
56
57
  from dara.core.internal.download import generate_download_code
57
58
  from dara.core.internal.registry_lookup import RegistryLookup
@@ -1192,6 +1193,70 @@ class ActionCtx:
1192
1193
  """
1193
1194
  return await DownloadVariable(variable=variable, file_name=file_name, type=type).execute(self)
1194
1195
 
1196
+ async def run_task(
1197
+ self,
1198
+ func: Callable,
1199
+ args: Union[List[Any], None] = None,
1200
+ kwargs: Union[Dict[str, Any], None] = None,
1201
+ on_progress: Optional[Callable[[TaskProgressUpdate], Union[None, Awaitable[None]]]] = None,
1202
+ ):
1203
+ """
1204
+ Run a calculation as a task in a separate process. Recommended for CPU intensive tasks.
1205
+ Returns the result of the task function.
1206
+
1207
+ Note that the function must be defined in a separate module as configured in `task_module` field of the
1208
+ configuration builder. This is because Dara spawns separate worker processes only designed to run
1209
+ functions from that designated module.
1210
+
1211
+ ```python
1212
+ from dara.core import ConfigurationBuilder, TaskProgressUpdate, action, ActionCtx, Variable
1213
+ from dara.components import Text, Stack, Button
1214
+ from .my_module import my_task_function
1215
+
1216
+ config = ConfigurationBuilder()
1217
+ config.task_module = 'my_module'
1218
+
1219
+ status = Variable('Not started')
1220
+
1221
+ @action
1222
+ async def my_task(ctx: ActionCtx):
1223
+ async def on_progress(update: TaskProgressUpdate):
1224
+ await ctx.update(status, f'Progress: {update.progress}% - {update.message}')
1225
+
1226
+ try:
1227
+ result = await ctx.run_task(my_task_function, args=[1, 10], on_progress=on_progress)
1228
+ await ctx.update(status, f'Result: {result}')
1229
+ except Exception as e:
1230
+ await ctx.update(status, f'Error: {e}')
1231
+
1232
+ def task_page():
1233
+ return Stack(Text('Status display:'), Text(text=status), Button('Run', onclick=my_task()))
1234
+
1235
+ config.add_page(name='task', content=task_page())
1236
+ ```
1237
+
1238
+ :param func: the function to run as a task
1239
+ :param args: the arguments to pass to the function
1240
+ :param kwargs: the keyword arguments to pass to the function
1241
+ :param on_progress: a callback to receive progress updates
1242
+ """
1243
+ from dara.core.internal.registries import utils_registry
1244
+ from dara.core.internal.tasks import Task, TaskManager
1245
+
1246
+ task_mgr: TaskManager = utils_registry.get('TaskManager')
1247
+
1248
+ task = Task(func=func, args=args, kwargs=kwargs, on_progress=on_progress)
1249
+ pending_task = await task_mgr.run_task(task)
1250
+
1251
+ # Run until completion
1252
+ await pending_task.event.wait()
1253
+
1254
+ # Raise exception if there was one
1255
+ if pending_task.error:
1256
+ raise pending_task.error
1257
+
1258
+ return pending_task.result
1259
+
1195
1260
  async def execute_action(self, action: ActionImpl):
1196
1261
  """
1197
1262
  Execute a given action.
@@ -90,7 +90,7 @@ class DerivedVariable(NonDataVariable, Generic[VariableType]):
90
90
 
91
91
  def __init__(
92
92
  self,
93
- func: Callable[..., VariableType],
93
+ func: Callable[..., VariableType] | Callable[..., Awaitable[VariableType]],
94
94
  variables: List[AnyVariable],
95
95
  cache: Optional[CacheArgType] = Cache.Type.GLOBAL,
96
96
  run_as_task: bool = False,
@@ -8,7 +8,7 @@ from .non_data_variable import NonDataVariable
8
8
  class LoopVariable(NonDataVariable):
9
9
  """
10
10
  A LoopVariable is a type of variable that represents an item in a list.
11
- It should be constructed using a parent Variable's `list_item()` method.
11
+ It should be constructed using a parent Variable's `.list_item` property.
12
12
  It should only be used in conjunction with the `For` component.
13
13
 
14
14
  By default, the entire value is used as the item and the index in the list is used as the unique key.
@@ -501,7 +501,7 @@ def create_router(config: Configuration):
501
501
  async def _write(store_uid: str, value: Any):
502
502
  WS_CHANNEL.set(ws_channel)
503
503
  store_entry: BackendStoreEntry = await registry_mgr.get(backend_store_registry, store_uid)
504
- result = store_entry.store.write(value)
504
+ result = store_entry.store.write(value, ignore_channel=ws_channel)
505
505
 
506
506
  # Backend implementation could return a coroutine
507
507
  if inspect.iscoroutine(result):
@@ -17,7 +17,7 @@ limitations under the License.
17
17
 
18
18
  import inspect
19
19
  import math
20
- from typing import Any, Callable, Dict, List, Optional, Union
20
+ from typing import Any, Awaitable, Callable, Dict, List, Optional, Union, overload
21
21
 
22
22
  from anyio import (
23
23
  CancelScope,
@@ -70,6 +70,7 @@ class Task(BaseTask):
70
70
  notify_channels: Optional[List[str]] = None,
71
71
  cache_key: Optional[str] = None,
72
72
  task_id: Optional[str] = None,
73
+ on_progress: Optional[Callable[[TaskProgressUpdate], Union[None, Awaitable[None]]]] = None,
73
74
  ):
74
75
  """
75
76
  :param func: The function to execute within the process
@@ -87,6 +88,7 @@ class Task(BaseTask):
87
88
  self.notify_channels = notify_channels if notify_channels is not None else []
88
89
  self.cache_key = cache_key
89
90
  self.reg_entry = reg_entry
91
+ self.on_progress = on_progress
90
92
 
91
93
  super().__init__(task_id)
92
94
 
@@ -359,6 +361,14 @@ class TaskManager:
359
361
  self.ws_manager = ws_manager
360
362
  self.store = store
361
363
 
364
+ @overload
365
+ async def run_task(self, task: PendingTask, ws_channel: Optional[str] = None) -> Any:
366
+ ...
367
+
368
+ @overload
369
+ async def run_task(self, task: BaseTask, ws_channel: Optional[str] = None) -> PendingTask:
370
+ ...
371
+
362
372
  async def run_task(self, task: BaseTask, ws_channel: Optional[str] = None):
363
373
  """
364
374
  Run a task and store it in the tasks dict
@@ -504,6 +514,8 @@ class TaskManager:
504
514
  'message': message.message,
505
515
  }
506
516
  )
517
+ if isinstance(task, Task) and task.on_progress:
518
+ await run_user_handler(task.on_progress, args=(message,))
507
519
  elif isinstance(message, TaskResult):
508
520
  # Resolve the pending task related to the result
509
521
  if message.task_id in self.tasks:
dara/core/persistence.py CHANGED
@@ -29,7 +29,6 @@ from pydantic import (
29
29
 
30
30
  from dara.core.auth.definitions import USER
31
31
  from dara.core.internal.utils import run_user_handler
32
- from dara.core.internal.websocket import WS_CHANNEL
33
32
  from dara.core.logging import dev_logger
34
33
 
35
34
  if TYPE_CHECKING:
@@ -318,7 +317,11 @@ class BackendStore(PersistenceStore):
318
317
  if not payload or len(payload) != 1:
319
318
  raise ValueError("Exactly one of 'value' or 'patches' must be provided")
320
319
 
321
- return {'store_uid': self.uid, 'sequence_number': self.sequence_number.get(scope_key, 0), **payload}
320
+ return {
321
+ 'store_uid': self.uid,
322
+ 'sequence_number': self.sequence_number.get(scope_key, 0),
323
+ **payload,
324
+ }
322
325
 
323
326
  def _get_next_sequence_number(self, key: str) -> int:
324
327
  """
@@ -330,39 +333,40 @@ class BackendStore(PersistenceStore):
330
333
  self.sequence_number[key] = current + 1
331
334
  return self.sequence_number[key]
332
335
 
333
- async def _notify_user(self, user_identifier: str, ignore_current_channel: bool = True, **payload):
336
+ async def _notify_user(self, user_identifier: str, ignore_channel: Optional[str] = None, **payload):
334
337
  """
335
338
  Notify a given user about updates to this store.
336
339
  :param user_identifier: user to notify
337
- :param ignore_current_channel: if True, ignore the current websocket channel
340
+ :param ignore_channel: if specified, ignore the specified channel
338
341
  :param payload: either value=... or patches=...
339
342
  """
340
343
  return await self.ws_mgr.send_message_to_user(
341
344
  user_identifier,
342
345
  self._create_msg(user_identifier, **payload),
343
- ignore_channel=WS_CHANNEL.get() if ignore_current_channel else None,
346
+ ignore_channel=ignore_channel,
344
347
  )
345
348
 
346
- async def _notify_global(self, ignore_current_channel: bool = True, **payload):
349
+ async def _notify_global(self, ignore_channel: Optional[str] = None, **payload):
347
350
  """
348
351
  Notify all users about updates to this store.
349
- :param ignore_current_channel: if True, ignore the current websocket channel
352
+ :param ignore_channel: if specified, ignore the specified channel
350
353
  :param payload: either value=... or patches=...
351
354
  """
352
355
  return await self.ws_mgr.broadcast(
353
356
  self._create_msg('global', **payload),
354
- ignore_channel=WS_CHANNEL.get() if ignore_current_channel else None,
357
+ ignore_channel=ignore_channel,
355
358
  )
356
359
 
357
- async def _notify_value(self, value: Any):
360
+ async def _notify_value(self, value: Any, ignore_channel: Optional[str] = None):
358
361
  """
359
362
  Notify all clients about the new value for this store.
360
363
  Broadcasts to all users if scope is global or sends to the current user if scope is user.
361
364
 
362
365
  :param value: value to notify about
366
+ :param ignore_channel: if passed, ignore the specified channel when broadcasting
363
367
  """
364
368
  if self.scope == 'global':
365
- return await self._notify_global(value=value)
369
+ return await self._notify_global(value=value, ignore_channel=ignore_channel)
366
370
 
367
371
  # For user scope, we need to find channels for the user and notify them
368
372
  user = USER.get()
@@ -371,7 +375,7 @@ class BackendStore(PersistenceStore):
371
375
  return
372
376
 
373
377
  user_identifier = user.identity_id or user.identity_name
374
- return await self._notify_user(user_identifier, value=value)
378
+ return await self._notify_user(user_identifier, value=value, ignore_channel=ignore_channel)
375
379
 
376
380
  async def _notify_patches(self, patches: List[Dict[str, Any]]):
377
381
  """
@@ -406,8 +410,8 @@ class BackendStore(PersistenceStore):
406
410
  async def _on_value(key: str, value: Any):
407
411
  # here we explicitly DON'T ignore the current channel, in case we created this variable inside e.g. a py_component we want to notify its creator as well
408
412
  if user := self._get_user(key):
409
- return await self._notify_user(user, ignore_current_channel=False, value=value)
410
- return await self._notify_global(ignore_current_channel=False, value=value)
413
+ return await self._notify_user(user, value=value)
414
+ return await self._notify_global(value=value)
411
415
 
412
416
  await self.backend.subscribe(_on_value)
413
417
 
@@ -466,7 +470,7 @@ class BackendStore(PersistenceStore):
466
470
 
467
471
  return updated_value
468
472
 
469
- async def write(self, value: Any, notify=True):
473
+ async def write(self, value: Any, notify=True, ignore_channel: Optional[str] = None):
470
474
  """
471
475
  Persist a value to the store.
472
476
 
@@ -475,6 +479,7 @@ class BackendStore(PersistenceStore):
475
479
 
476
480
  :param value: value to write
477
481
  :param notify: whether to broadcast the new value to clients
482
+ :param ignore_channel: if passed, ignore the specified websocket channel when broadcasting
478
483
  """
479
484
  if self.readonly:
480
485
  raise ValueError('Cannot write to a read-only store')
@@ -486,7 +491,7 @@ class BackendStore(PersistenceStore):
486
491
  self._get_next_sequence_number(key)
487
492
 
488
493
  if notify:
489
- await self._notify_value(value)
494
+ await self._notify_value(value, ignore_channel=ignore_channel)
490
495
 
491
496
  return res
492
497