flwr-nightly 1.19.0.dev20250610__py3-none-any.whl → 1.19.0.dev20250612__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45) hide show
  1. flwr/client/grpc_rere_client/connection.py +48 -29
  2. flwr/client/grpc_rere_client/grpc_adapter.py +8 -0
  3. flwr/client/rest_client/connection.py +138 -27
  4. flwr/common/auth_plugin/auth_plugin.py +6 -4
  5. flwr/common/event_log_plugin/event_log_plugin.py +3 -3
  6. flwr/common/inflatable.py +70 -1
  7. flwr/common/inflatable_grpc_utils.py +1 -1
  8. flwr/common/inflatable_rest_utils.py +99 -0
  9. flwr/common/serde.py +2 -0
  10. flwr/common/typing.py +5 -3
  11. flwr/proto/fleet_pb2.py +12 -16
  12. flwr/proto/fleet_pb2.pyi +4 -19
  13. flwr/proto/fleet_pb2_grpc.py +34 -0
  14. flwr/proto/fleet_pb2_grpc.pyi +13 -0
  15. flwr/proto/message_pb2.py +15 -9
  16. flwr/proto/message_pb2.pyi +41 -0
  17. flwr/proto/run_pb2.py +24 -24
  18. flwr/proto/run_pb2.pyi +4 -1
  19. flwr/proto/serverappio_pb2.py +22 -26
  20. flwr/proto/serverappio_pb2.pyi +4 -19
  21. flwr/proto/serverappio_pb2_grpc.py +34 -0
  22. flwr/proto/serverappio_pb2_grpc.pyi +13 -0
  23. flwr/server/fleet_event_log_interceptor.py +2 -2
  24. flwr/server/grid/grpc_grid.py +20 -9
  25. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +25 -0
  26. flwr/server/superlink/fleet/message_handler/message_handler.py +33 -2
  27. flwr/server/superlink/fleet/rest_rere/rest_api.py +56 -2
  28. flwr/server/superlink/linkstate/in_memory_linkstate.py +17 -2
  29. flwr/server/superlink/linkstate/linkstate.py +6 -2
  30. flwr/server/superlink/linkstate/sqlite_linkstate.py +19 -7
  31. flwr/server/superlink/serverappio/serverappio_servicer.py +65 -29
  32. flwr/server/superlink/simulation/simulationio_servicer.py +2 -1
  33. flwr/server/superlink/utils.py +23 -10
  34. flwr/supercore/object_store/in_memory_object_store.py +160 -33
  35. flwr/supercore/object_store/object_store.py +54 -7
  36. flwr/superexec/deployment.py +6 -2
  37. flwr/superexec/exec_event_log_interceptor.py +4 -4
  38. flwr/superexec/exec_servicer.py +4 -1
  39. flwr/superexec/exec_user_auth_interceptor.py +11 -11
  40. flwr/superexec/executor.py +4 -0
  41. flwr/superexec/simulation.py +7 -1
  42. {flwr_nightly-1.19.0.dev20250610.dist-info → flwr_nightly-1.19.0.dev20250612.dist-info}/METADATA +1 -1
  43. {flwr_nightly-1.19.0.dev20250610.dist-info → flwr_nightly-1.19.0.dev20250612.dist-info}/RECORD +45 -44
  44. {flwr_nightly-1.19.0.dev20250610.dist-info → flwr_nightly-1.19.0.dev20250612.dist-info}/WHEEL +0 -0
  45. {flwr_nightly-1.19.0.dev20250610.dist-info → flwr_nightly-1.19.0.dev20250612.dist-info}/entry_points.txt +0 -0
@@ -26,6 +26,8 @@ from flwr.common.constant import SUPERLINK_NODE_ID, Status
26
26
  from flwr.common.inflatable import (
27
27
  UnexpectedObjectContentError,
28
28
  get_descendant_object_ids,
29
+ get_object_tree,
30
+ no_object_id_recompute,
29
31
  )
30
32
  from flwr.common.logger import log
31
33
  from flwr.common.serde import (
@@ -50,6 +52,8 @@ from flwr.proto.log_pb2 import ( # pylint: disable=E0611
50
52
  PushLogsResponse,
51
53
  )
52
54
  from flwr.proto.message_pb2 import ( # pylint: disable=E0611
55
+ ConfirmMessageReceivedRequest,
56
+ ConfirmMessageReceivedResponse,
53
57
  ObjectIDs,
54
58
  PullObjectRequest,
55
59
  PullObjectResponse,
@@ -107,14 +111,16 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
107
111
  """Get available nodes."""
108
112
  log(DEBUG, "ServerAppIoServicer.GetNodes")
109
113
 
110
- # Init state
111
- state: LinkState = self.state_factory.state()
114
+ # Init state and store
115
+ state = self.state_factory.state()
116
+ store = self.objectstore_factory.store()
112
117
 
113
118
  # Abort if the run is not running
114
119
  abort_if(
115
120
  request.run_id,
116
121
  [Status.PENDING, Status.STARTING, Status.FINISHED],
117
122
  state,
123
+ store,
118
124
  context,
119
125
  )
120
126
 
@@ -128,14 +134,16 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
128
134
  """Push a set of Messages."""
129
135
  log(DEBUG, "ServerAppIoServicer.PushMessages")
130
136
 
131
- # Init state
132
- state: LinkState = self.state_factory.state()
137
+ # Init state and store
138
+ state = self.state_factory.state()
139
+ store = self.objectstore_factory.store()
133
140
 
134
141
  # Abort if the run is not running
135
142
  abort_if(
136
143
  request.run_id,
137
144
  [Status.PENDING, Status.STARTING, Status.FINISHED],
138
145
  state,
146
+ store,
139
147
  context,
140
148
  )
141
149
 
@@ -146,8 +154,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
146
154
  detail="`messages_list` must not be empty",
147
155
  )
148
156
  message_ids: list[Optional[str]] = []
149
- while request.messages_list:
150
- message_proto = request.messages_list.pop(0)
157
+ for message_proto in request.messages_list:
151
158
  message = message_from_proto(message_proto=message_proto)
152
159
  validation_errors = validate_message(message, is_reply_message=False)
153
160
  _raise_if(
@@ -164,9 +171,6 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
164
171
  message_id: Optional[str] = state.store_message_ins(message=message)
165
172
  message_ids.append(message_id)
166
173
 
167
- # Init store
168
- store = self.objectstore_factory.store()
169
-
170
174
  # Store Message object to descendants mapping and preregister objects
171
175
  objects_to_push = store_mapping_and_register_objects(store, request=request)
172
176
 
@@ -183,10 +187,8 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
183
187
  """Pull a set of Messages."""
184
188
  log(DEBUG, "ServerAppIoServicer.PullMessages")
185
189
 
186
- # Init state
187
- state: LinkState = self.state_factory.state()
188
-
189
- # Init store
190
+ # Init state and store
191
+ state = self.state_factory.state()
190
192
  store = self.objectstore_factory.store()
191
193
 
192
194
  # Abort if the run is not running
@@ -194,6 +196,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
194
196
  request.run_id,
195
197
  [Status.PENDING, Status.STARTING, Status.FINISHED],
196
198
  state,
199
+ store,
197
200
  context,
198
201
  )
199
202
 
@@ -205,14 +208,15 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
205
208
  # Register messages generated by LinkState in the Store for consistency
206
209
  for msg_res in messages_res:
207
210
  if msg_res.metadata.src_node_id == SUPERLINK_NODE_ID:
208
- descendants = list(get_descendant_object_ids(msg_res))
209
- message_obj_id = msg_res.metadata.message_id
211
+ with no_object_id_recompute():
212
+ descendants = list(get_descendant_object_ids(msg_res))
213
+ message_obj_id = msg_res.metadata.message_id
210
214
  # Store mapping
211
215
  store.set_message_descendant_ids(
212
216
  msg_object_id=message_obj_id, descendant_ids=descendants
213
217
  )
214
218
  # Preregister
215
- store.preregister(descendants + [message_obj_id])
219
+ store.preregister(request.run_id, get_object_tree(msg_res))
216
220
 
217
221
  # Delete the instruction Messages and their replies if found
218
222
  message_ins_ids_to_delete = {
@@ -328,14 +332,16 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
328
332
  """Push ServerApp process outputs."""
329
333
  log(DEBUG, "ServerAppIoServicer.PushServerAppOutputs")
330
334
 
331
- # Init state
335
+ # Init state and store
332
336
  state = self.state_factory.state()
337
+ store = self.objectstore_factory.store()
333
338
 
334
339
  # Abort if the run is not running
335
340
  abort_if(
336
341
  request.run_id,
337
342
  [Status.PENDING, Status.STARTING, Status.FINISHED],
338
343
  state,
344
+ store,
339
345
  context,
340
346
  )
341
347
 
@@ -348,16 +354,23 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
348
354
  """Update the status of a run."""
349
355
  log(DEBUG, "ServerAppIoServicer.UpdateRunStatus")
350
356
 
351
- # Init state
357
+ # Init state and store
352
358
  state = self.state_factory.state()
359
+ store = self.objectstore_factory.store()
353
360
 
354
361
  # Abort if the run is finished
355
- abort_if(request.run_id, [Status.FINISHED], state, context)
362
+ abort_if(request.run_id, [Status.FINISHED], state, store, context)
356
363
 
357
364
  # Update the run status
358
365
  state.update_run_status(
359
366
  run_id=request.run_id, new_status=run_status_from_proto(request.run_status)
360
367
  )
368
+
369
+ # If the run is finished, delete the run from ObjectStore
370
+ if request.run_status.status == Status.FINISHED:
371
+ # Delete all objects related to the run
372
+ store.delete_objects_in_run(request.run_id)
373
+
361
374
  return UpdateRunStatusResponse()
362
375
 
363
376
  def PushLogs(
@@ -412,14 +425,16 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
412
425
  """Push an object to the ObjectStore."""
413
426
  log(DEBUG, "ServerAppIoServicer.PushObject")
414
427
 
415
- # Init state
416
- state: LinkState = self.state_factory.state()
428
+ # Init state and store
429
+ state = self.state_factory.state()
430
+ store = self.objectstore_factory.store()
417
431
 
418
432
  # Abort if the run is not running
419
433
  abort_if(
420
434
  request.run_id,
421
435
  [Status.PENDING, Status.STARTING, Status.FINISHED],
422
436
  state,
437
+ store,
423
438
  context,
424
439
  )
425
440
 
@@ -427,9 +442,6 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
427
442
  # Cancel insertion in ObjectStore
428
443
  context.abort(grpc.StatusCode.FAILED_PRECONDITION, "Unexpected node ID.")
429
444
 
430
- # Init store
431
- store = self.objectstore_factory.store()
432
-
433
445
  # Insert in store
434
446
  stored = False
435
447
  try:
@@ -449,14 +461,16 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
449
461
  """Pull an object from the ObjectStore."""
450
462
  log(DEBUG, "ServerAppIoServicer.PullObject")
451
463
 
452
- # Init state
453
- state: LinkState = self.state_factory.state()
464
+ # Init state and store
465
+ state = self.state_factory.state()
466
+ store = self.objectstore_factory.store()
454
467
 
455
468
  # Abort if the run is not running
456
469
  abort_if(
457
470
  request.run_id,
458
471
  [Status.PENDING, Status.STARTING, Status.FINISHED],
459
472
  state,
473
+ store,
460
474
  context,
461
475
  )
462
476
 
@@ -464,9 +478,6 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
464
478
  # Cancel insertion in ObjectStore
465
479
  context.abort(grpc.StatusCode.FAILED_PRECONDITION, "Unexpected node ID.")
466
480
 
467
- # Init store
468
- store = self.objectstore_factory.store()
469
-
470
481
  # Fetch from store
471
482
  content = store.get(request.object_id)
472
483
  if content is not None:
@@ -478,6 +489,31 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
478
489
  )
479
490
  return PullObjectResponse(object_found=False, object_available=False)
480
491
 
492
+ def ConfirmMessageReceived(
493
+ self, request: ConfirmMessageReceivedRequest, context: grpc.ServicerContext
494
+ ) -> ConfirmMessageReceivedResponse:
495
+ """Confirm message received."""
496
+ log(DEBUG, "ServerAppIoServicer.ConfirmMessageReceived")
497
+
498
+ # Init state and store
499
+ state = self.state_factory.state()
500
+ store = self.objectstore_factory.store()
501
+
502
+ # Abort if the run is not running
503
+ abort_if(
504
+ request.run_id,
505
+ [Status.PENDING, Status.STARTING, Status.FINISHED],
506
+ state,
507
+ store,
508
+ context,
509
+ )
510
+
511
+ # Delete the message object
512
+ store.delete(request.message_object_id)
513
+ store.delete_message_descendant_ids(request.message_object_id)
514
+
515
+ return ConfirmMessageReceivedResponse()
516
+
481
517
 
482
518
  def _raise_if(validation_error: bool, request_name: str, detail: str) -> None:
483
519
  """Raise a `ValueError` with a detailed message if a validation error occurs."""
@@ -121,6 +121,7 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
121
121
  request.run_id,
122
122
  [Status.PENDING, Status.STARTING, Status.FINISHED],
123
123
  state,
124
+ None,
124
125
  context,
125
126
  )
126
127
 
@@ -135,7 +136,7 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
135
136
  state = self.state_factory.state()
136
137
 
137
138
  # Abort if the run is finished
138
- abort_if(request.run_id, [Status.FINISHED], state, context)
139
+ abort_if(request.run_id, [Status.FINISHED], state, None, context)
139
140
 
140
141
  # Update the run status
141
142
  state.update_run_status(
@@ -15,11 +15,12 @@
15
15
  """SuperLink utilities."""
16
16
 
17
17
 
18
- from typing import Union
18
+ from typing import Optional, Union
19
19
 
20
20
  import grpc
21
21
 
22
22
  from flwr.common.constant import Status, SubStatus
23
+ from flwr.common.inflatable import iterate_object_tree
23
24
  from flwr.common.typing import RunStatus
24
25
  from flwr.proto.fleet_pb2 import PushMessagesRequest # pylint: disable=E0611
25
26
  from flwr.proto.message_pb2 import ObjectIDs # pylint: disable=E0611
@@ -39,6 +40,7 @@ def check_abort(
39
40
  run_id: int,
40
41
  abort_status_list: list[str],
41
42
  state: LinkState,
43
+ store: Optional[ObjectStore] = None,
42
44
  ) -> Union[str, None]:
43
45
  """Check if the status of the provided `run_id` is in `abort_status_list`."""
44
46
  run_status: RunStatus = state.get_run_status({run_id})[run_id]
@@ -49,6 +51,10 @@ def check_abort(
49
51
  msg += " Stopped by user."
50
52
  return msg
51
53
 
54
+ # Clear the objects of the run from the store if the run is finished
55
+ if store and run_status.status == Status.FINISHED:
56
+ store.delete_objects_in_run(run_id)
57
+
52
58
  return None
53
59
 
54
60
 
@@ -62,10 +68,11 @@ def abort_if(
62
68
  run_id: int,
63
69
  abort_status_list: list[str],
64
70
  state: LinkState,
71
+ store: Optional[ObjectStore],
65
72
  context: grpc.ServicerContext,
66
73
  ) -> None:
67
74
  """Abort context if status of the provided `run_id` is in `abort_status_list`."""
68
- msg = check_abort(run_id, abort_status_list, state)
75
+ msg = check_abort(run_id, abort_status_list, state, store)
69
76
  abort_grpc_context(msg, context)
70
77
 
71
78
 
@@ -73,21 +80,27 @@ def store_mapping_and_register_objects(
73
80
  store: ObjectStore, request: Union[PushInsMessagesRequest, PushMessagesRequest]
74
81
  ) -> dict[str, ObjectIDs]:
75
82
  """Store Message object to descendants mapping and preregister objects."""
83
+ if not request.messages_list:
84
+ return {}
85
+
76
86
  objects_to_push: dict[str, ObjectIDs] = {}
77
- for (
78
- message_obj_id,
79
- descendant_obj_ids,
80
- ) in request.msg_to_descendant_mapping.items():
81
- descendants = list(descendant_obj_ids.object_ids)
87
+
88
+ # Get run_id from the first message in the list
89
+ # All messages of a request should in the same run
90
+ run_id = request.messages_list[0].metadata.run_id
91
+
92
+ for object_tree in request.message_object_trees:
93
+ all_object_ids = [obj.object_id for obj in iterate_object_tree(object_tree)]
94
+ msg_object_id, descendant_ids = all_object_ids[-1], all_object_ids[:-1]
82
95
  # Store mapping
83
96
  store.set_message_descendant_ids(
84
- msg_object_id=message_obj_id, descendant_ids=descendants
97
+ msg_object_id=msg_object_id, descendant_ids=descendant_ids
85
98
  )
86
99
 
87
100
  # Preregister
88
- object_ids_just_registered = store.preregister(descendants + [message_obj_id])
101
+ object_ids_just_registered = store.preregister(run_id, object_tree)
89
102
  # Keep track of objects that need to be pushed
90
- objects_to_push[message_obj_id] = ObjectIDs(
103
+ objects_to_push[msg_object_id] = ObjectIDs(
91
104
  object_ids=object_ids_just_registered
92
105
  )
93
106
 
@@ -15,44 +15,95 @@
15
15
  """Flower in-memory ObjectStore implementation."""
16
16
 
17
17
 
18
+ import threading
19
+ from dataclasses import dataclass
18
20
  from typing import Optional
19
21
 
20
- from flwr.common.inflatable import get_object_id, is_valid_sha256_hash
22
+ from flwr.common.inflatable import (
23
+ get_object_children_ids_from_object_content,
24
+ get_object_id,
25
+ is_valid_sha256_hash,
26
+ iterate_object_tree,
27
+ )
21
28
  from flwr.common.inflatable_utils import validate_object_content
29
+ from flwr.proto.message_pb2 import ObjectTree # pylint: disable=E0611
22
30
 
23
31
  from .object_store import NoObjectInStoreError, ObjectStore
24
32
 
25
33
 
34
+ @dataclass
35
+ class ObjectEntry:
36
+ """Data class representing an object entry in the store."""
37
+
38
+ content: bytes
39
+ is_available: bool
40
+ ref_count: int # Number of references (direct parents) to this object
41
+ runs: set[int] # Set of run IDs that used this object
42
+
43
+
26
44
  class InMemoryObjectStore(ObjectStore):
27
45
  """In-memory implementation of the ObjectStore interface."""
28
46
 
29
47
  def __init__(self, verify: bool = True) -> None:
30
48
  self.verify = verify
31
- self.store: dict[str, bytes] = {}
32
- # Mapping the Object ID of a message to the list of children object IDs
33
- self.msg_children_objects_mapping: dict[str, list[str]] = {}
34
-
35
- def preregister(self, object_ids: list[str]) -> list[str]:
49
+ self.store: dict[str, ObjectEntry] = {}
50
+ self.lock_store = threading.RLock()
51
+ # Mapping the Object ID of a message to the list of descendant object IDs
52
+ self.msg_descendant_objects_mapping: dict[str, list[str]] = {}
53
+ self.lock_msg_mapping = threading.RLock()
54
+ # Mapping each run ID to a set of object IDs that are used in that run
55
+ self.run_objects_mapping: dict[int, set[str]] = {}
56
+
57
+ def preregister(self, run_id: int, object_tree: ObjectTree) -> list[str]:
36
58
  """Identify and preregister missing objects."""
37
59
  new_objects = []
38
- for obj_id in object_ids:
60
+ if run_id not in self.run_objects_mapping:
61
+ self.run_objects_mapping[run_id] = set()
62
+
63
+ for tree_node in iterate_object_tree(object_tree):
64
+ obj_id = tree_node.object_id
39
65
  # Verify object ID format (must be a valid sha256 hash)
40
66
  if not is_valid_sha256_hash(obj_id):
41
67
  raise ValueError(f"Invalid object ID format: {obj_id}")
42
- if obj_id not in self.store:
43
- self.store[obj_id] = b""
44
- new_objects.append(obj_id)
68
+ with self.lock_store:
69
+ if obj_id not in self.store:
70
+ self.store[obj_id] = ObjectEntry(
71
+ content=b"", # Initially empty content
72
+ is_available=False, # Initially not available
73
+ ref_count=0, # Reference count starts at 0
74
+ runs={run_id}, # Start with the current run ID
75
+ )
76
+
77
+ # Increment the reference count for all its children
78
+ # Post-order traversal ensures that children are registered
79
+ # before parents
80
+ for child_node in tree_node.children:
81
+ child_id = child_node.object_id
82
+ self.store[child_id].ref_count += 1
83
+
84
+ # Add the object ID to the run's mapping
85
+ self.run_objects_mapping[run_id].add(obj_id)
86
+
87
+ # Add to the list of new objects
88
+ new_objects.append(obj_id)
89
+ else:
90
+ # Object is in store, retrieve it
91
+ obj_entry = self.store[obj_id]
92
+
93
+ # Add to the list of new objects if not available
94
+ if not obj_entry.is_available:
95
+ new_objects.append(obj_id)
96
+
97
+ # If the object is already registered but not in this run,
98
+ # add the run ID to its runs
99
+ if obj_id not in self.run_objects_mapping[run_id]:
100
+ obj_entry.runs.add(run_id)
101
+ self.run_objects_mapping[run_id].add(obj_id)
45
102
 
46
103
  return new_objects
47
104
 
48
105
  def put(self, object_id: str, object_content: bytes) -> None:
49
106
  """Put an object into the store."""
50
- # Only allow adding the object if it has been preregistered
51
- if object_id not in self.store:
52
- raise NoObjectInStoreError(
53
- f"Object with ID '{object_id}' was not pre-registered."
54
- )
55
-
56
107
  if self.verify:
57
108
  # Verify object_id and object_content match
58
109
  object_id_from_content = get_object_id(object_content)
@@ -62,41 +113,117 @@ class InMemoryObjectStore(ObjectStore):
62
113
  # Validate object content
63
114
  validate_object_content(content=object_content)
64
115
 
65
- # Return if object is already present in the store
66
- if self.store[object_id] != b"":
67
- return
116
+ with self.lock_store:
117
+ # Only allow adding the object if it has been preregistered
118
+ if object_id not in self.store:
119
+ raise NoObjectInStoreError(
120
+ f"Object with ID '{object_id}' was not pre-registered."
121
+ )
122
+
123
+ # Return if object is already present in the store
124
+ if self.store[object_id].is_available:
125
+ return
68
126
 
69
- self.store[object_id] = object_content
127
+ # Update the object entry in the store
128
+ self.store[object_id].content = object_content
129
+ self.store[object_id].is_available = True
70
130
 
71
131
  def set_message_descendant_ids(
72
132
  self, msg_object_id: str, descendant_ids: list[str]
73
133
  ) -> None:
74
134
  """Store the mapping from a ``Message`` object ID to the object IDs of its
75
135
  descendants."""
76
- self.msg_children_objects_mapping[msg_object_id] = descendant_ids
136
+ with self.lock_msg_mapping:
137
+ self.msg_descendant_objects_mapping[msg_object_id] = descendant_ids
77
138
 
78
139
  def get_message_descendant_ids(self, msg_object_id: str) -> list[str]:
79
140
  """Retrieve the object IDs of all descendants of a given Message."""
80
- if msg_object_id not in self.msg_children_objects_mapping:
81
- raise NoObjectInStoreError(
82
- f"No message registered in Object Store with ID '{msg_object_id}'. "
83
- "Mapping to descendants could not be found."
84
- )
85
- return self.msg_children_objects_mapping[msg_object_id]
141
+ with self.lock_msg_mapping:
142
+ if msg_object_id not in self.msg_descendant_objects_mapping:
143
+ raise NoObjectInStoreError(
144
+ f"No message registered in Object Store with ID '{msg_object_id}'. "
145
+ "Mapping to descendants could not be found."
146
+ )
147
+ return self.msg_descendant_objects_mapping[msg_object_id]
148
+
149
+ def delete_message_descendant_ids(self, msg_object_id: str) -> None:
150
+ """Delete the mapping from a ``Message`` object ID to its descendants."""
151
+ with self.lock_msg_mapping:
152
+ self.msg_descendant_objects_mapping.pop(msg_object_id, None)
86
153
 
87
154
  def get(self, object_id: str) -> Optional[bytes]:
88
155
  """Get an object from the store."""
89
- return self.store.get(object_id)
156
+ with self.lock_store:
157
+ # Check if the object ID is pre-registered
158
+ if object_id not in self.store:
159
+ return None
160
+
161
+ # Return content (if not yet available, it will b"")
162
+ return self.store[object_id].content
90
163
 
91
164
  def delete(self, object_id: str) -> None:
92
- """Delete an object from the store."""
93
- if object_id in self.store:
94
- del self.store[object_id]
165
+ """Delete an object and its unreferenced descendants from the store."""
166
+ with self.lock_store:
167
+ # If the object is not in the store, nothing to delete
168
+ if (object_entry := self.store.get(object_id)) is None:
169
+ return
170
+
171
+ # Delete the object if it has no references left
172
+ if object_entry.ref_count == 0:
173
+ del self.store[object_id]
174
+
175
+ # Remove the object from the run's mapping
176
+ for run_id in object_entry.runs:
177
+ self.run_objects_mapping[run_id].discard(object_id)
178
+
179
+ # Decrease the reference count of its children
180
+ children_ids = get_object_children_ids_from_object_content(
181
+ object_entry.content
182
+ )
183
+ for child_id in children_ids:
184
+ self.store[child_id].ref_count -= 1
185
+
186
+ # Recursively try to delete the child object
187
+ self.delete(child_id)
188
+
189
+ def delete_objects_in_run(self, run_id: int) -> None:
190
+ """Delete all objects that were registered in a specific run."""
191
+ with self.lock_store:
192
+ if run_id not in self.run_objects_mapping:
193
+ return
194
+ for object_id in list(self.run_objects_mapping[run_id]):
195
+ # Check if the object is still in the store
196
+ if (object_entry := self.store.get(object_id)) is None:
197
+ continue
198
+
199
+ # Remove the run ID from the object's runs
200
+ object_entry.runs.discard(run_id)
201
+
202
+ # Only message objects are allowed to have a `ref_count` of 0,
203
+ # and every message object must have a `ref_count` of 0
204
+ if object_entry.ref_count == 0:
205
+ # Delete the message object and its unreferenced descendants
206
+ self.delete(object_id)
207
+
208
+ # Delete the message's descendants mapping
209
+ self.delete_message_descendant_ids(object_id)
210
+
211
+ # Remove the run from the mapping
212
+ del self.run_objects_mapping[run_id]
95
213
 
96
214
  def clear(self) -> None:
97
215
  """Clear the store."""
98
- self.store.clear()
216
+ with self.lock_store:
217
+ self.store.clear()
218
+ self.msg_descendant_objects_mapping.clear()
219
+ self.run_objects_mapping.clear()
99
220
 
100
221
  def __contains__(self, object_id: str) -> bool:
101
222
  """Check if an object_id is in the store."""
102
- return object_id in self.store
223
+ with self.lock_store:
224
+ return object_id in self.store
225
+
226
+ def __len__(self) -> int:
227
+ """Get the number of objects in the store."""
228
+ with self.lock_store:
229
+ return len(self.store)