dara-core 1.16.17__py3-none-any.whl → 1.16.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -84,7 +84,7 @@ class DerivedVariable(NonDataVariable, Generic[VariableType]):
84
84
  variables: List[AnyVariable]
85
85
  polling_interval: Optional[int]
86
86
  deps: Optional[List[AnyVariable]] = Field(validate_default=True)
87
- nested: List[str] = []
87
+ nested: List[str] = Field(default_factory=list)
88
88
  uid: str
89
89
  model_config = ConfigDict(extra='forbid', use_enum_values=True)
90
90
 
@@ -97,6 +97,7 @@ class DerivedVariable(NonDataVariable, Generic[VariableType]):
97
97
  polling_interval: Optional[int] = None,
98
98
  deps: Optional[List[AnyVariable]] = None,
99
99
  uid: Optional[str] = None,
100
+ nested: Optional[List[str]] = None,
100
101
  _get_value: Optional[Callable[..., Awaitable[Any]]] = None,
101
102
  ):
102
103
  """
@@ -124,6 +125,9 @@ class DerivedVariable(NonDataVariable, Generic[VariableType]):
124
125
  - `deps = [var1.get('nested_property')]` - `func` is ran only when the nested property changes, other changes to the variable are ignored
125
126
  :param uid: the unique identifier for this variable; if not provided a random one is generated
126
127
  """
128
+ if nested is None:
129
+ nested = []
130
+
127
131
  if cache is not None:
128
132
  cache = Cache.Policy.from_arg(cache)
129
133
 
@@ -141,7 +145,9 @@ class DerivedVariable(NonDataVariable, Generic[VariableType]):
141
145
  if get_ipython() is not None:
142
146
  raise RuntimeError('run_as_task is not supported within a Jupyter environment')
143
147
 
144
- super().__init__(cache=cache, uid=uid, variables=variables, polling_interval=polling_interval, deps=deps)
148
+ super().__init__(
149
+ cache=cache, uid=uid, variables=variables, polling_interval=polling_interval, deps=deps, nested=nested
150
+ )
145
151
 
146
152
  # Import the registry of variables and register the function at import
147
153
  from dara.core.internal.registries import derived_variable_registry
@@ -24,6 +24,7 @@ from typing import Any, Callable, Generic, List, Optional, TypeVar
24
24
  from fastapi.encoders import jsonable_encoder
25
25
  from pydantic import (
26
26
  ConfigDict,
27
+ Field,
27
28
  SerializerFunctionWrapHandler,
28
29
  field_serializer,
29
30
  model_serializer,
@@ -65,7 +66,7 @@ class Variable(NonDataVariable, Generic[VariableType]):
65
66
  persist_value: bool = False
66
67
  store: Optional[PersistenceStore] = None
67
68
  uid: str
68
- nested: List[str] = []
69
+ nested: List[str] = Field(default_factory=list)
69
70
  model_config = ConfigDict(extra='forbid')
70
71
 
71
72
  def __init__(
@@ -86,7 +87,7 @@ class Variable(NonDataVariable, Generic[VariableType]):
86
87
  """
87
88
  if nested is None:
88
89
  nested = []
89
- kwargs = {'default': default, 'persist_value': persist_value, 'uid': uid, 'store': store}
90
+ kwargs = {'default': default, 'persist_value': persist_value, 'uid': uid, 'store': store, 'nested': nested}
90
91
 
91
92
  # If an override is active, run the kwargs through it
92
93
  override = VARIABLE_INIT_OVERRIDE.get()
@@ -174,7 +175,7 @@ class Variable(NonDataVariable, Generic[VariableType]):
174
175
  :param key: the key to access; must be a string
175
176
  ```
176
177
  """
177
- return self.copy(update={'nested': [*self.nested, key]}, deep=True)
178
+ return self.model_copy(update={'nested': [*self.nested, key]}, deep=True)
178
179
 
179
180
  def sync(self):
180
181
  """
@@ -487,7 +487,12 @@ def create_router(config: Configuration):
487
487
  if inspect.iscoroutine(result):
488
488
  result = await result
489
489
 
490
- return result
490
+ # Get the current key and sequence number for this store
491
+ store = store_entry.store
492
+ key = await store._get_key()
493
+ sequence_number = store.sequence_number.get(key, 0)
494
+
495
+ return {'value': result, 'sequence_number': sequence_number}
491
496
 
492
497
  @core_api_router.post('/store', dependencies=[Depends(verify_session)])
493
498
  async def sync_backend_store(ws_channel: str = Body(), values: Dict[str, Any] = Body()):
dara/core/persistence.py CHANGED
@@ -1,11 +1,23 @@
1
1
  import abc
2
2
  import json
3
3
  import os
4
- from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Literal, Optional, Set
4
+ from typing import (
5
+ TYPE_CHECKING,
6
+ Any,
7
+ Awaitable,
8
+ Callable,
9
+ Dict,
10
+ List,
11
+ Literal,
12
+ Optional,
13
+ Set,
14
+ Union,
15
+ )
5
16
  from uuid import uuid4
6
17
 
7
18
  import aiorwlock
8
19
  import anyio
20
+ import jsonpatch
9
21
  from pydantic import (
10
22
  BaseModel,
11
23
  Field,
@@ -189,6 +201,9 @@ class BackendStore(PersistenceStore):
189
201
 
190
202
  default_value: Any = Field(default=None, exclude=True)
191
203
  initialized_scopes: Set[str] = Field(default_factory=set, exclude=True)
204
+ sequence_number: Dict[str, int] = Field(
205
+ default_factory=dict, exclude=True
206
+ ) # Track sequence numbers per user for patch validation
192
207
 
193
208
  def __init__(
194
209
  self,
@@ -233,6 +248,8 @@ class BackendStore(PersistenceStore):
233
248
  self.initialized_scopes.add('global')
234
249
  if not await run_user_handler(self.backend.has, args=(key,)):
235
250
  await run_user_handler(self.backend.write, (key, self.default_value))
251
+ # Initialize sequence number for this key
252
+ self.sequence_number[key] = 0
236
253
 
237
254
  return key
238
255
 
@@ -246,6 +263,8 @@ class BackendStore(PersistenceStore):
246
263
  self.initialized_scopes.add(user_key)
247
264
  if not await run_user_handler(self.backend.has, args=(user_key,)):
248
265
  await run_user_handler(self.backend.write, (user_key, self.default_value))
266
+ # Initialize sequence number for this key
267
+ self.sequence_number[user_key] = 0
249
268
 
250
269
  return user_key
251
270
 
@@ -290,36 +309,48 @@ class BackendStore(PersistenceStore):
290
309
 
291
310
  return utils_registry.get('WebsocketManager')
292
311
 
293
- def _create_msg(self, value: Any) -> Dict[str, Any]:
312
+ def _create_msg(self, scope_key: str, **payload) -> Dict[str, Any]:
294
313
  """
295
314
  Create a message to send to the frontend.
296
- :param value: value to send
315
+ :param scope_key: scope key for sequence number
316
+ :param payload: either value=... or patches=...
297
317
  """
298
- return {'store_uid': self.uid, 'value': value}
318
+ if not payload or len(payload) != 1:
319
+ raise ValueError("Exactly one of 'value' or 'patches' must be provided")
320
+
321
+ return {'store_uid': self.uid, 'sequence_number': self.sequence_number.get(scope_key, 0), **payload}
299
322
 
300
- async def _notify_user(self, user_identifier: str, value: Any, ignore_current_channel: bool = True):
323
+ def _get_next_sequence_number(self, key: str) -> int:
301
324
  """
302
- Notify a given user about the new value for this store.
325
+ Get the next sequence number for this store.
303
326
 
327
+ :param key: key for the store
328
+ """
329
+ current = self.sequence_number.get(key, 0)
330
+ self.sequence_number[key] = current + 1
331
+ return self.sequence_number[key]
332
+
333
+ async def _notify_user(self, user_identifier: str, ignore_current_channel: bool = True, **payload):
334
+ """
335
+ Notify a given user about updates to this store.
304
336
  :param user_identifier: user to notify
305
- :param value: value to notify about
306
337
  :param ignore_current_channel: if True, ignore the current websocket channel
338
+ :param payload: either value=... or patches=...
307
339
  """
308
340
  return await self.ws_mgr.send_message_to_user(
309
341
  user_identifier,
310
- self._create_msg(value),
342
+ self._create_msg(user_identifier, **payload),
311
343
  ignore_channel=WS_CHANNEL.get() if ignore_current_channel else None,
312
344
  )
313
345
 
314
- async def _notify_global(self, value: Any, ignore_current_channel: bool = True):
346
+ async def _notify_global(self, ignore_current_channel: bool = True, **payload):
315
347
  """
316
- Notify all users about the new value for this store.
317
-
318
- :param value: value to notify about
348
+ Notify all users about updates to this store.
319
349
  :param ignore_current_channel: if True, ignore the current websocket channel
350
+ :param payload: either value=... or patches=...
320
351
  """
321
352
  return await self.ws_mgr.broadcast(
322
- self._create_msg(value),
353
+ self._create_msg('global', **payload),
323
354
  ignore_channel=WS_CHANNEL.get() if ignore_current_channel else None,
324
355
  )
325
356
 
@@ -331,7 +362,7 @@ class BackendStore(PersistenceStore):
331
362
  :param value: value to notify about
332
363
  """
333
364
  if self.scope == 'global':
334
- return await self._notify_global(value)
365
+ return await self._notify_global(value=value)
335
366
 
336
367
  # For user scope, we need to find channels for the user and notify them
337
368
  user = USER.get()
@@ -340,7 +371,26 @@ class BackendStore(PersistenceStore):
340
371
  return
341
372
 
342
373
  user_identifier = user.identity_id or user.identity_name
343
- return await self._notify_user(user_identifier, value)
374
+ return await self._notify_user(user_identifier, value=value)
375
+
376
+ async def _notify_patches(self, patches: List[Dict[str, Any]]):
377
+ """
378
+ Notify all clients about partial updates to this store.
379
+ Broadcasts to all users if scope is global or sends to the current user if scope is user.
380
+
381
+ :param patches: list of JSON patch operations
382
+ """
383
+ if self.scope == 'global':
384
+ return await self._notify_global(patches=patches)
385
+
386
+ # For user scope, we need to find channels for the user and notify them
387
+ user = USER.get()
388
+
389
+ if not user:
390
+ return
391
+
392
+ user_identifier = user.identity_id or user.identity_name
393
+ return await self._notify_user(user_identifier, patches=patches)
344
394
 
345
395
  async def init(self, variable: 'Variable'):
346
396
  """
@@ -356,11 +406,66 @@ class BackendStore(PersistenceStore):
356
406
  async def _on_value(key: str, value: Any):
357
407
  # here we explicitly DON'T ignore the current channel, in case we created this variable inside e.g. a py_component we want to notify its creator as well
358
408
  if user := self._get_user(key):
359
- return await self._notify_user(user, value, ignore_current_channel=False)
360
- return await self._notify_global(value, ignore_current_channel=False)
409
+ return await self._notify_user(user, ignore_current_channel=False, value=value)
410
+ return await self._notify_global(ignore_current_channel=False, value=value)
361
411
 
362
412
  await self.backend.subscribe(_on_value)
363
413
 
414
+ async def write_partial(self, data: Union[List[Dict[str, Any]], Any], notify: bool = True):
415
+ """
416
+ Apply partial updates to the store using JSON Patch operations or automatic diffing.
417
+
418
+ If scope='user', the patches are applied for the current user so the method can only
419
+ be used in authenticated contexts.
420
+
421
+ :param data: Either a list of JSON patch operations (RFC 6902) or a full object to diff against current value
422
+ :param notify: whether to broadcast the patches to clients
423
+ """
424
+ if self.readonly:
425
+ raise ValueError('Cannot write to a read-only store')
426
+
427
+ key = await self._get_key()
428
+
429
+ # Read current value
430
+ current_value = await run_user_handler(self.backend.read, (key,))
431
+
432
+ if current_value is None:
433
+ # If no current value, create an empty dict as the base
434
+ current_value = {}
435
+
436
+ # Determine if data is patches or a full object
437
+ if isinstance(data, list) and all(isinstance(item, dict) and 'op' in item for item in data):
438
+ # Data is a list of patch operations
439
+ patches = data
440
+
441
+ if not isinstance(current_value, (dict, list)):
442
+ # JSON patches can only be applied to structured data (objects/arrays)
443
+ raise ValueError(
444
+ f'Cannot apply JSON patches to non-structured data. '
445
+ f'Current value is of type {type(current_value).__name__}, but patches require dict or list.'
446
+ )
447
+
448
+ # Apply patches to current value
449
+ try:
450
+ updated_value = jsonpatch.apply_patch(current_value, patches)
451
+ except (jsonpatch.InvalidJsonPatch, jsonpatch.JsonPatchException) as e:
452
+ raise ValueError(f'Invalid JSON patch operation: {e}') from e
453
+ else:
454
+ # Data is a full object - generate patches by diffing
455
+ patches = jsonpatch.make_patch(current_value, data).patch
456
+ updated_value = data
457
+
458
+ # Write updated value back to store
459
+ await run_user_handler(self.backend.write, (key, updated_value))
460
+ # Increment sequence number for this update
461
+ self._get_next_sequence_number(key)
462
+
463
+ if notify:
464
+ # Notify clients about the patches, not the full value
465
+ await self._notify_patches(patches)
466
+
467
+ return updated_value
468
+
364
469
  async def write(self, value: Any, notify=True):
365
470
  """
366
471
  Persist a value to the store.
@@ -376,10 +481,14 @@ class BackendStore(PersistenceStore):
376
481
 
377
482
  key = await self._get_key()
378
483
 
484
+ res = await run_user_handler(self.backend.write, (key, value))
485
+ # Increment sequence number for this update
486
+ self._get_next_sequence_number(key)
487
+
379
488
  if notify:
380
489
  await self._notify_value(value)
381
490
 
382
- return await run_user_handler(self.backend.write, (key, value))
491
+ return res
383
492
 
384
493
  async def read(self):
385
494
  """