flwr-nightly 1.19.0.dev20250610__py3-none-any.whl → 1.19.0.dev20250612__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/client/grpc_rere_client/connection.py +48 -29
- flwr/client/grpc_rere_client/grpc_adapter.py +8 -0
- flwr/client/rest_client/connection.py +138 -27
- flwr/common/auth_plugin/auth_plugin.py +6 -4
- flwr/common/event_log_plugin/event_log_plugin.py +3 -3
- flwr/common/inflatable.py +70 -1
- flwr/common/inflatable_grpc_utils.py +1 -1
- flwr/common/inflatable_rest_utils.py +99 -0
- flwr/common/serde.py +2 -0
- flwr/common/typing.py +5 -3
- flwr/proto/fleet_pb2.py +12 -16
- flwr/proto/fleet_pb2.pyi +4 -19
- flwr/proto/fleet_pb2_grpc.py +34 -0
- flwr/proto/fleet_pb2_grpc.pyi +13 -0
- flwr/proto/message_pb2.py +15 -9
- flwr/proto/message_pb2.pyi +41 -0
- flwr/proto/run_pb2.py +24 -24
- flwr/proto/run_pb2.pyi +4 -1
- flwr/proto/serverappio_pb2.py +22 -26
- flwr/proto/serverappio_pb2.pyi +4 -19
- flwr/proto/serverappio_pb2_grpc.py +34 -0
- flwr/proto/serverappio_pb2_grpc.pyi +13 -0
- flwr/server/fleet_event_log_interceptor.py +2 -2
- flwr/server/grid/grpc_grid.py +20 -9
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +25 -0
- flwr/server/superlink/fleet/message_handler/message_handler.py +33 -2
- flwr/server/superlink/fleet/rest_rere/rest_api.py +56 -2
- flwr/server/superlink/linkstate/in_memory_linkstate.py +17 -2
- flwr/server/superlink/linkstate/linkstate.py +6 -2
- flwr/server/superlink/linkstate/sqlite_linkstate.py +19 -7
- flwr/server/superlink/serverappio/serverappio_servicer.py +65 -29
- flwr/server/superlink/simulation/simulationio_servicer.py +2 -1
- flwr/server/superlink/utils.py +23 -10
- flwr/supercore/object_store/in_memory_object_store.py +160 -33
- flwr/supercore/object_store/object_store.py +54 -7
- flwr/superexec/deployment.py +6 -2
- flwr/superexec/exec_event_log_interceptor.py +4 -4
- flwr/superexec/exec_servicer.py +4 -1
- flwr/superexec/exec_user_auth_interceptor.py +11 -11
- flwr/superexec/executor.py +4 -0
- flwr/superexec/simulation.py +7 -1
- {flwr_nightly-1.19.0.dev20250610.dist-info → flwr_nightly-1.19.0.dev20250612.dist-info}/METADATA +1 -1
- {flwr_nightly-1.19.0.dev20250610.dist-info → flwr_nightly-1.19.0.dev20250612.dist-info}/RECORD +45 -44
- {flwr_nightly-1.19.0.dev20250610.dist-info → flwr_nightly-1.19.0.dev20250612.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.19.0.dev20250610.dist-info → flwr_nightly-1.19.0.dev20250612.dist-info}/entry_points.txt +0 -0
@@ -26,6 +26,8 @@ from flwr.common.constant import SUPERLINK_NODE_ID, Status
|
|
26
26
|
from flwr.common.inflatable import (
|
27
27
|
UnexpectedObjectContentError,
|
28
28
|
get_descendant_object_ids,
|
29
|
+
get_object_tree,
|
30
|
+
no_object_id_recompute,
|
29
31
|
)
|
30
32
|
from flwr.common.logger import log
|
31
33
|
from flwr.common.serde import (
|
@@ -50,6 +52,8 @@ from flwr.proto.log_pb2 import ( # pylint: disable=E0611
|
|
50
52
|
PushLogsResponse,
|
51
53
|
)
|
52
54
|
from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
55
|
+
ConfirmMessageReceivedRequest,
|
56
|
+
ConfirmMessageReceivedResponse,
|
53
57
|
ObjectIDs,
|
54
58
|
PullObjectRequest,
|
55
59
|
PullObjectResponse,
|
@@ -107,14 +111,16 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
107
111
|
"""Get available nodes."""
|
108
112
|
log(DEBUG, "ServerAppIoServicer.GetNodes")
|
109
113
|
|
110
|
-
# Init state
|
111
|
-
state
|
114
|
+
# Init state and store
|
115
|
+
state = self.state_factory.state()
|
116
|
+
store = self.objectstore_factory.store()
|
112
117
|
|
113
118
|
# Abort if the run is not running
|
114
119
|
abort_if(
|
115
120
|
request.run_id,
|
116
121
|
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
117
122
|
state,
|
123
|
+
store,
|
118
124
|
context,
|
119
125
|
)
|
120
126
|
|
@@ -128,14 +134,16 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
128
134
|
"""Push a set of Messages."""
|
129
135
|
log(DEBUG, "ServerAppIoServicer.PushMessages")
|
130
136
|
|
131
|
-
# Init state
|
132
|
-
state
|
137
|
+
# Init state and store
|
138
|
+
state = self.state_factory.state()
|
139
|
+
store = self.objectstore_factory.store()
|
133
140
|
|
134
141
|
# Abort if the run is not running
|
135
142
|
abort_if(
|
136
143
|
request.run_id,
|
137
144
|
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
138
145
|
state,
|
146
|
+
store,
|
139
147
|
context,
|
140
148
|
)
|
141
149
|
|
@@ -146,8 +154,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
146
154
|
detail="`messages_list` must not be empty",
|
147
155
|
)
|
148
156
|
message_ids: list[Optional[str]] = []
|
149
|
-
|
150
|
-
message_proto = request.messages_list.pop(0)
|
157
|
+
for message_proto in request.messages_list:
|
151
158
|
message = message_from_proto(message_proto=message_proto)
|
152
159
|
validation_errors = validate_message(message, is_reply_message=False)
|
153
160
|
_raise_if(
|
@@ -164,9 +171,6 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
164
171
|
message_id: Optional[str] = state.store_message_ins(message=message)
|
165
172
|
message_ids.append(message_id)
|
166
173
|
|
167
|
-
# Init store
|
168
|
-
store = self.objectstore_factory.store()
|
169
|
-
|
170
174
|
# Store Message object to descendants mapping and preregister objects
|
171
175
|
objects_to_push = store_mapping_and_register_objects(store, request=request)
|
172
176
|
|
@@ -183,10 +187,8 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
183
187
|
"""Pull a set of Messages."""
|
184
188
|
log(DEBUG, "ServerAppIoServicer.PullMessages")
|
185
189
|
|
186
|
-
# Init state
|
187
|
-
state
|
188
|
-
|
189
|
-
# Init store
|
190
|
+
# Init state and store
|
191
|
+
state = self.state_factory.state()
|
190
192
|
store = self.objectstore_factory.store()
|
191
193
|
|
192
194
|
# Abort if the run is not running
|
@@ -194,6 +196,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
194
196
|
request.run_id,
|
195
197
|
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
196
198
|
state,
|
199
|
+
store,
|
197
200
|
context,
|
198
201
|
)
|
199
202
|
|
@@ -205,14 +208,15 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
205
208
|
# Register messages generated by LinkState in the Store for consistency
|
206
209
|
for msg_res in messages_res:
|
207
210
|
if msg_res.metadata.src_node_id == SUPERLINK_NODE_ID:
|
208
|
-
|
209
|
-
|
211
|
+
with no_object_id_recompute():
|
212
|
+
descendants = list(get_descendant_object_ids(msg_res))
|
213
|
+
message_obj_id = msg_res.metadata.message_id
|
210
214
|
# Store mapping
|
211
215
|
store.set_message_descendant_ids(
|
212
216
|
msg_object_id=message_obj_id, descendant_ids=descendants
|
213
217
|
)
|
214
218
|
# Preregister
|
215
|
-
store.preregister(
|
219
|
+
store.preregister(request.run_id, get_object_tree(msg_res))
|
216
220
|
|
217
221
|
# Delete the instruction Messages and their replies if found
|
218
222
|
message_ins_ids_to_delete = {
|
@@ -328,14 +332,16 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
328
332
|
"""Push ServerApp process outputs."""
|
329
333
|
log(DEBUG, "ServerAppIoServicer.PushServerAppOutputs")
|
330
334
|
|
331
|
-
# Init state
|
335
|
+
# Init state and store
|
332
336
|
state = self.state_factory.state()
|
337
|
+
store = self.objectstore_factory.store()
|
333
338
|
|
334
339
|
# Abort if the run is not running
|
335
340
|
abort_if(
|
336
341
|
request.run_id,
|
337
342
|
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
338
343
|
state,
|
344
|
+
store,
|
339
345
|
context,
|
340
346
|
)
|
341
347
|
|
@@ -348,16 +354,23 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
348
354
|
"""Update the status of a run."""
|
349
355
|
log(DEBUG, "ServerAppIoServicer.UpdateRunStatus")
|
350
356
|
|
351
|
-
# Init state
|
357
|
+
# Init state and store
|
352
358
|
state = self.state_factory.state()
|
359
|
+
store = self.objectstore_factory.store()
|
353
360
|
|
354
361
|
# Abort if the run is finished
|
355
|
-
abort_if(request.run_id, [Status.FINISHED], state, context)
|
362
|
+
abort_if(request.run_id, [Status.FINISHED], state, store, context)
|
356
363
|
|
357
364
|
# Update the run status
|
358
365
|
state.update_run_status(
|
359
366
|
run_id=request.run_id, new_status=run_status_from_proto(request.run_status)
|
360
367
|
)
|
368
|
+
|
369
|
+
# If the run is finished, delete the run from ObjectStore
|
370
|
+
if request.run_status.status == Status.FINISHED:
|
371
|
+
# Delete all objects related to the run
|
372
|
+
store.delete_objects_in_run(request.run_id)
|
373
|
+
|
361
374
|
return UpdateRunStatusResponse()
|
362
375
|
|
363
376
|
def PushLogs(
|
@@ -412,14 +425,16 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
412
425
|
"""Push an object to the ObjectStore."""
|
413
426
|
log(DEBUG, "ServerAppIoServicer.PushObject")
|
414
427
|
|
415
|
-
# Init state
|
416
|
-
state
|
428
|
+
# Init state and store
|
429
|
+
state = self.state_factory.state()
|
430
|
+
store = self.objectstore_factory.store()
|
417
431
|
|
418
432
|
# Abort if the run is not running
|
419
433
|
abort_if(
|
420
434
|
request.run_id,
|
421
435
|
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
422
436
|
state,
|
437
|
+
store,
|
423
438
|
context,
|
424
439
|
)
|
425
440
|
|
@@ -427,9 +442,6 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
427
442
|
# Cancel insertion in ObjectStore
|
428
443
|
context.abort(grpc.StatusCode.FAILED_PRECONDITION, "Unexpected node ID.")
|
429
444
|
|
430
|
-
# Init store
|
431
|
-
store = self.objectstore_factory.store()
|
432
|
-
|
433
445
|
# Insert in store
|
434
446
|
stored = False
|
435
447
|
try:
|
@@ -449,14 +461,16 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
449
461
|
"""Pull an object from the ObjectStore."""
|
450
462
|
log(DEBUG, "ServerAppIoServicer.PullObject")
|
451
463
|
|
452
|
-
# Init state
|
453
|
-
state
|
464
|
+
# Init state and store
|
465
|
+
state = self.state_factory.state()
|
466
|
+
store = self.objectstore_factory.store()
|
454
467
|
|
455
468
|
# Abort if the run is not running
|
456
469
|
abort_if(
|
457
470
|
request.run_id,
|
458
471
|
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
459
472
|
state,
|
473
|
+
store,
|
460
474
|
context,
|
461
475
|
)
|
462
476
|
|
@@ -464,9 +478,6 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
464
478
|
# Cancel insertion in ObjectStore
|
465
479
|
context.abort(grpc.StatusCode.FAILED_PRECONDITION, "Unexpected node ID.")
|
466
480
|
|
467
|
-
# Init store
|
468
|
-
store = self.objectstore_factory.store()
|
469
|
-
|
470
481
|
# Fetch from store
|
471
482
|
content = store.get(request.object_id)
|
472
483
|
if content is not None:
|
@@ -478,6 +489,31 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
478
489
|
)
|
479
490
|
return PullObjectResponse(object_found=False, object_available=False)
|
480
491
|
|
492
|
+
def ConfirmMessageReceived(
|
493
|
+
self, request: ConfirmMessageReceivedRequest, context: grpc.ServicerContext
|
494
|
+
) -> ConfirmMessageReceivedResponse:
|
495
|
+
"""Confirm message received."""
|
496
|
+
log(DEBUG, "ServerAppIoServicer.ConfirmMessageReceived")
|
497
|
+
|
498
|
+
# Init state and store
|
499
|
+
state = self.state_factory.state()
|
500
|
+
store = self.objectstore_factory.store()
|
501
|
+
|
502
|
+
# Abort if the run is not running
|
503
|
+
abort_if(
|
504
|
+
request.run_id,
|
505
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
506
|
+
state,
|
507
|
+
store,
|
508
|
+
context,
|
509
|
+
)
|
510
|
+
|
511
|
+
# Delete the message object
|
512
|
+
store.delete(request.message_object_id)
|
513
|
+
store.delete_message_descendant_ids(request.message_object_id)
|
514
|
+
|
515
|
+
return ConfirmMessageReceivedResponse()
|
516
|
+
|
481
517
|
|
482
518
|
def _raise_if(validation_error: bool, request_name: str, detail: str) -> None:
|
483
519
|
"""Raise a `ValueError` with a detailed message if a validation error occurs."""
|
@@ -121,6 +121,7 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
|
|
121
121
|
request.run_id,
|
122
122
|
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
123
123
|
state,
|
124
|
+
None,
|
124
125
|
context,
|
125
126
|
)
|
126
127
|
|
@@ -135,7 +136,7 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
|
|
135
136
|
state = self.state_factory.state()
|
136
137
|
|
137
138
|
# Abort if the run is finished
|
138
|
-
abort_if(request.run_id, [Status.FINISHED], state, context)
|
139
|
+
abort_if(request.run_id, [Status.FINISHED], state, None, context)
|
139
140
|
|
140
141
|
# Update the run status
|
141
142
|
state.update_run_status(
|
flwr/server/superlink/utils.py
CHANGED
@@ -15,11 +15,12 @@
|
|
15
15
|
"""SuperLink utilities."""
|
16
16
|
|
17
17
|
|
18
|
-
from typing import Union
|
18
|
+
from typing import Optional, Union
|
19
19
|
|
20
20
|
import grpc
|
21
21
|
|
22
22
|
from flwr.common.constant import Status, SubStatus
|
23
|
+
from flwr.common.inflatable import iterate_object_tree
|
23
24
|
from flwr.common.typing import RunStatus
|
24
25
|
from flwr.proto.fleet_pb2 import PushMessagesRequest # pylint: disable=E0611
|
25
26
|
from flwr.proto.message_pb2 import ObjectIDs # pylint: disable=E0611
|
@@ -39,6 +40,7 @@ def check_abort(
|
|
39
40
|
run_id: int,
|
40
41
|
abort_status_list: list[str],
|
41
42
|
state: LinkState,
|
43
|
+
store: Optional[ObjectStore] = None,
|
42
44
|
) -> Union[str, None]:
|
43
45
|
"""Check if the status of the provided `run_id` is in `abort_status_list`."""
|
44
46
|
run_status: RunStatus = state.get_run_status({run_id})[run_id]
|
@@ -49,6 +51,10 @@ def check_abort(
|
|
49
51
|
msg += " Stopped by user."
|
50
52
|
return msg
|
51
53
|
|
54
|
+
# Clear the objects of the run from the store if the run is finished
|
55
|
+
if store and run_status.status == Status.FINISHED:
|
56
|
+
store.delete_objects_in_run(run_id)
|
57
|
+
|
52
58
|
return None
|
53
59
|
|
54
60
|
|
@@ -62,10 +68,11 @@ def abort_if(
|
|
62
68
|
run_id: int,
|
63
69
|
abort_status_list: list[str],
|
64
70
|
state: LinkState,
|
71
|
+
store: Optional[ObjectStore],
|
65
72
|
context: grpc.ServicerContext,
|
66
73
|
) -> None:
|
67
74
|
"""Abort context if status of the provided `run_id` is in `abort_status_list`."""
|
68
|
-
msg = check_abort(run_id, abort_status_list, state)
|
75
|
+
msg = check_abort(run_id, abort_status_list, state, store)
|
69
76
|
abort_grpc_context(msg, context)
|
70
77
|
|
71
78
|
|
@@ -73,21 +80,27 @@ def store_mapping_and_register_objects(
|
|
73
80
|
store: ObjectStore, request: Union[PushInsMessagesRequest, PushMessagesRequest]
|
74
81
|
) -> dict[str, ObjectIDs]:
|
75
82
|
"""Store Message object to descendants mapping and preregister objects."""
|
83
|
+
if not request.messages_list:
|
84
|
+
return {}
|
85
|
+
|
76
86
|
objects_to_push: dict[str, ObjectIDs] = {}
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
87
|
+
|
88
|
+
# Get run_id from the first message in the list
|
89
|
+
# All messages of a request should in the same run
|
90
|
+
run_id = request.messages_list[0].metadata.run_id
|
91
|
+
|
92
|
+
for object_tree in request.message_object_trees:
|
93
|
+
all_object_ids = [obj.object_id for obj in iterate_object_tree(object_tree)]
|
94
|
+
msg_object_id, descendant_ids = all_object_ids[-1], all_object_ids[:-1]
|
82
95
|
# Store mapping
|
83
96
|
store.set_message_descendant_ids(
|
84
|
-
msg_object_id=
|
97
|
+
msg_object_id=msg_object_id, descendant_ids=descendant_ids
|
85
98
|
)
|
86
99
|
|
87
100
|
# Preregister
|
88
|
-
object_ids_just_registered = store.preregister(
|
101
|
+
object_ids_just_registered = store.preregister(run_id, object_tree)
|
89
102
|
# Keep track of objects that need to be pushed
|
90
|
-
objects_to_push[
|
103
|
+
objects_to_push[msg_object_id] = ObjectIDs(
|
91
104
|
object_ids=object_ids_just_registered
|
92
105
|
)
|
93
106
|
|
@@ -15,44 +15,95 @@
|
|
15
15
|
"""Flower in-memory ObjectStore implementation."""
|
16
16
|
|
17
17
|
|
18
|
+
import threading
|
19
|
+
from dataclasses import dataclass
|
18
20
|
from typing import Optional
|
19
21
|
|
20
|
-
from flwr.common.inflatable import
|
22
|
+
from flwr.common.inflatable import (
|
23
|
+
get_object_children_ids_from_object_content,
|
24
|
+
get_object_id,
|
25
|
+
is_valid_sha256_hash,
|
26
|
+
iterate_object_tree,
|
27
|
+
)
|
21
28
|
from flwr.common.inflatable_utils import validate_object_content
|
29
|
+
from flwr.proto.message_pb2 import ObjectTree # pylint: disable=E0611
|
22
30
|
|
23
31
|
from .object_store import NoObjectInStoreError, ObjectStore
|
24
32
|
|
25
33
|
|
34
|
+
@dataclass
|
35
|
+
class ObjectEntry:
|
36
|
+
"""Data class representing an object entry in the store."""
|
37
|
+
|
38
|
+
content: bytes
|
39
|
+
is_available: bool
|
40
|
+
ref_count: int # Number of references (direct parents) to this object
|
41
|
+
runs: set[int] # Set of run IDs that used this object
|
42
|
+
|
43
|
+
|
26
44
|
class InMemoryObjectStore(ObjectStore):
|
27
45
|
"""In-memory implementation of the ObjectStore interface."""
|
28
46
|
|
29
47
|
def __init__(self, verify: bool = True) -> None:
|
30
48
|
self.verify = verify
|
31
|
-
self.store: dict[str,
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
49
|
+
self.store: dict[str, ObjectEntry] = {}
|
50
|
+
self.lock_store = threading.RLock()
|
51
|
+
# Mapping the Object ID of a message to the list of descendant object IDs
|
52
|
+
self.msg_descendant_objects_mapping: dict[str, list[str]] = {}
|
53
|
+
self.lock_msg_mapping = threading.RLock()
|
54
|
+
# Mapping each run ID to a set of object IDs that are used in that run
|
55
|
+
self.run_objects_mapping: dict[int, set[str]] = {}
|
56
|
+
|
57
|
+
def preregister(self, run_id: int, object_tree: ObjectTree) -> list[str]:
|
36
58
|
"""Identify and preregister missing objects."""
|
37
59
|
new_objects = []
|
38
|
-
|
60
|
+
if run_id not in self.run_objects_mapping:
|
61
|
+
self.run_objects_mapping[run_id] = set()
|
62
|
+
|
63
|
+
for tree_node in iterate_object_tree(object_tree):
|
64
|
+
obj_id = tree_node.object_id
|
39
65
|
# Verify object ID format (must be a valid sha256 hash)
|
40
66
|
if not is_valid_sha256_hash(obj_id):
|
41
67
|
raise ValueError(f"Invalid object ID format: {obj_id}")
|
42
|
-
|
43
|
-
self.store
|
44
|
-
|
68
|
+
with self.lock_store:
|
69
|
+
if obj_id not in self.store:
|
70
|
+
self.store[obj_id] = ObjectEntry(
|
71
|
+
content=b"", # Initially empty content
|
72
|
+
is_available=False, # Initially not available
|
73
|
+
ref_count=0, # Reference count starts at 0
|
74
|
+
runs={run_id}, # Start with the current run ID
|
75
|
+
)
|
76
|
+
|
77
|
+
# Increment the reference count for all its children
|
78
|
+
# Post-order traversal ensures that children are registered
|
79
|
+
# before parents
|
80
|
+
for child_node in tree_node.children:
|
81
|
+
child_id = child_node.object_id
|
82
|
+
self.store[child_id].ref_count += 1
|
83
|
+
|
84
|
+
# Add the object ID to the run's mapping
|
85
|
+
self.run_objects_mapping[run_id].add(obj_id)
|
86
|
+
|
87
|
+
# Add to the list of new objects
|
88
|
+
new_objects.append(obj_id)
|
89
|
+
else:
|
90
|
+
# Object is in store, retrieve it
|
91
|
+
obj_entry = self.store[obj_id]
|
92
|
+
|
93
|
+
# Add to the list of new objects if not available
|
94
|
+
if not obj_entry.is_available:
|
95
|
+
new_objects.append(obj_id)
|
96
|
+
|
97
|
+
# If the object is already registered but not in this run,
|
98
|
+
# add the run ID to its runs
|
99
|
+
if obj_id not in self.run_objects_mapping[run_id]:
|
100
|
+
obj_entry.runs.add(run_id)
|
101
|
+
self.run_objects_mapping[run_id].add(obj_id)
|
45
102
|
|
46
103
|
return new_objects
|
47
104
|
|
48
105
|
def put(self, object_id: str, object_content: bytes) -> None:
|
49
106
|
"""Put an object into the store."""
|
50
|
-
# Only allow adding the object if it has been preregistered
|
51
|
-
if object_id not in self.store:
|
52
|
-
raise NoObjectInStoreError(
|
53
|
-
f"Object with ID '{object_id}' was not pre-registered."
|
54
|
-
)
|
55
|
-
|
56
107
|
if self.verify:
|
57
108
|
# Verify object_id and object_content match
|
58
109
|
object_id_from_content = get_object_id(object_content)
|
@@ -62,41 +113,117 @@ class InMemoryObjectStore(ObjectStore):
|
|
62
113
|
# Validate object content
|
63
114
|
validate_object_content(content=object_content)
|
64
115
|
|
65
|
-
|
66
|
-
|
67
|
-
|
116
|
+
with self.lock_store:
|
117
|
+
# Only allow adding the object if it has been preregistered
|
118
|
+
if object_id not in self.store:
|
119
|
+
raise NoObjectInStoreError(
|
120
|
+
f"Object with ID '{object_id}' was not pre-registered."
|
121
|
+
)
|
122
|
+
|
123
|
+
# Return if object is already present in the store
|
124
|
+
if self.store[object_id].is_available:
|
125
|
+
return
|
68
126
|
|
69
|
-
|
127
|
+
# Update the object entry in the store
|
128
|
+
self.store[object_id].content = object_content
|
129
|
+
self.store[object_id].is_available = True
|
70
130
|
|
71
131
|
def set_message_descendant_ids(
|
72
132
|
self, msg_object_id: str, descendant_ids: list[str]
|
73
133
|
) -> None:
|
74
134
|
"""Store the mapping from a ``Message`` object ID to the object IDs of its
|
75
135
|
descendants."""
|
76
|
-
self.
|
136
|
+
with self.lock_msg_mapping:
|
137
|
+
self.msg_descendant_objects_mapping[msg_object_id] = descendant_ids
|
77
138
|
|
78
139
|
def get_message_descendant_ids(self, msg_object_id: str) -> list[str]:
|
79
140
|
"""Retrieve the object IDs of all descendants of a given Message."""
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
141
|
+
with self.lock_msg_mapping:
|
142
|
+
if msg_object_id not in self.msg_descendant_objects_mapping:
|
143
|
+
raise NoObjectInStoreError(
|
144
|
+
f"No message registered in Object Store with ID '{msg_object_id}'. "
|
145
|
+
"Mapping to descendants could not be found."
|
146
|
+
)
|
147
|
+
return self.msg_descendant_objects_mapping[msg_object_id]
|
148
|
+
|
149
|
+
def delete_message_descendant_ids(self, msg_object_id: str) -> None:
|
150
|
+
"""Delete the mapping from a ``Message`` object ID to its descendants."""
|
151
|
+
with self.lock_msg_mapping:
|
152
|
+
self.msg_descendant_objects_mapping.pop(msg_object_id, None)
|
86
153
|
|
87
154
|
def get(self, object_id: str) -> Optional[bytes]:
|
88
155
|
"""Get an object from the store."""
|
89
|
-
|
156
|
+
with self.lock_store:
|
157
|
+
# Check if the object ID is pre-registered
|
158
|
+
if object_id not in self.store:
|
159
|
+
return None
|
160
|
+
|
161
|
+
# Return content (if not yet available, it will b"")
|
162
|
+
return self.store[object_id].content
|
90
163
|
|
91
164
|
def delete(self, object_id: str) -> None:
|
92
|
-
"""Delete an object from the store."""
|
93
|
-
|
94
|
-
|
165
|
+
"""Delete an object and its unreferenced descendants from the store."""
|
166
|
+
with self.lock_store:
|
167
|
+
# If the object is not in the store, nothing to delete
|
168
|
+
if (object_entry := self.store.get(object_id)) is None:
|
169
|
+
return
|
170
|
+
|
171
|
+
# Delete the object if it has no references left
|
172
|
+
if object_entry.ref_count == 0:
|
173
|
+
del self.store[object_id]
|
174
|
+
|
175
|
+
# Remove the object from the run's mapping
|
176
|
+
for run_id in object_entry.runs:
|
177
|
+
self.run_objects_mapping[run_id].discard(object_id)
|
178
|
+
|
179
|
+
# Decrease the reference count of its children
|
180
|
+
children_ids = get_object_children_ids_from_object_content(
|
181
|
+
object_entry.content
|
182
|
+
)
|
183
|
+
for child_id in children_ids:
|
184
|
+
self.store[child_id].ref_count -= 1
|
185
|
+
|
186
|
+
# Recursively try to delete the child object
|
187
|
+
self.delete(child_id)
|
188
|
+
|
189
|
+
def delete_objects_in_run(self, run_id: int) -> None:
|
190
|
+
"""Delete all objects that were registered in a specific run."""
|
191
|
+
with self.lock_store:
|
192
|
+
if run_id not in self.run_objects_mapping:
|
193
|
+
return
|
194
|
+
for object_id in list(self.run_objects_mapping[run_id]):
|
195
|
+
# Check if the object is still in the store
|
196
|
+
if (object_entry := self.store.get(object_id)) is None:
|
197
|
+
continue
|
198
|
+
|
199
|
+
# Remove the run ID from the object's runs
|
200
|
+
object_entry.runs.discard(run_id)
|
201
|
+
|
202
|
+
# Only message objects are allowed to have a `ref_count` of 0,
|
203
|
+
# and every message object must have a `ref_count` of 0
|
204
|
+
if object_entry.ref_count == 0:
|
205
|
+
# Delete the message object and its unreferenced descendants
|
206
|
+
self.delete(object_id)
|
207
|
+
|
208
|
+
# Delete the message's descendants mapping
|
209
|
+
self.delete_message_descendant_ids(object_id)
|
210
|
+
|
211
|
+
# Remove the run from the mapping
|
212
|
+
del self.run_objects_mapping[run_id]
|
95
213
|
|
96
214
|
def clear(self) -> None:
|
97
215
|
"""Clear the store."""
|
98
|
-
self.
|
216
|
+
with self.lock_store:
|
217
|
+
self.store.clear()
|
218
|
+
self.msg_descendant_objects_mapping.clear()
|
219
|
+
self.run_objects_mapping.clear()
|
99
220
|
|
100
221
|
def __contains__(self, object_id: str) -> bool:
|
101
222
|
"""Check if an object_id is in the store."""
|
102
|
-
|
223
|
+
with self.lock_store:
|
224
|
+
return object_id in self.store
|
225
|
+
|
226
|
+
def __len__(self) -> int:
|
227
|
+
"""Get the number of objects in the store."""
|
228
|
+
with self.lock_store:
|
229
|
+
return len(self.store)
|