flwr-nightly 1.19.0.dev20250611__py3-none-any.whl → 1.19.0.dev20250613__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. flwr/cli/ls.py +12 -33
  2. flwr/cli/utils.py +18 -1
  3. flwr/client/grpc_rere_client/connection.py +47 -29
  4. flwr/client/grpc_rere_client/grpc_adapter.py +8 -0
  5. flwr/client/rest_client/connection.py +70 -51
  6. flwr/common/constant.py +4 -0
  7. flwr/common/inflatable.py +24 -0
  8. flwr/common/serde.py +2 -0
  9. flwr/common/typing.py +2 -0
  10. flwr/proto/fleet_pb2.py +12 -16
  11. flwr/proto/fleet_pb2.pyi +4 -19
  12. flwr/proto/fleet_pb2_grpc.py +34 -0
  13. flwr/proto/fleet_pb2_grpc.pyi +13 -0
  14. flwr/proto/message_pb2.py +15 -9
  15. flwr/proto/message_pb2.pyi +41 -0
  16. flwr/proto/run_pb2.py +24 -24
  17. flwr/proto/run_pb2.pyi +4 -1
  18. flwr/proto/serverappio_pb2.py +22 -26
  19. flwr/proto/serverappio_pb2.pyi +4 -19
  20. flwr/proto/serverappio_pb2_grpc.py +34 -0
  21. flwr/proto/serverappio_pb2_grpc.pyi +13 -0
  22. flwr/server/app.py +1 -0
  23. flwr/server/grid/grpc_grid.py +20 -9
  24. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +25 -0
  25. flwr/server/superlink/fleet/message_handler/message_handler.py +33 -2
  26. flwr/server/superlink/fleet/rest_rere/rest_api.py +26 -2
  27. flwr/server/superlink/linkstate/in_memory_linkstate.py +20 -3
  28. flwr/server/superlink/linkstate/linkstate.py +6 -2
  29. flwr/server/superlink/linkstate/sqlite_linkstate.py +19 -7
  30. flwr/server/superlink/serverappio/serverappio_servicer.py +65 -29
  31. flwr/server/superlink/simulation/simulationio_servicer.py +2 -1
  32. flwr/server/superlink/utils.py +23 -10
  33. flwr/supercore/object_store/in_memory_object_store.py +160 -33
  34. flwr/supercore/object_store/object_store.py +54 -7
  35. flwr/superexec/deployment.py +6 -2
  36. flwr/superexec/exec_grpc.py +3 -0
  37. flwr/superexec/exec_servicer.py +125 -22
  38. flwr/superexec/executor.py +4 -0
  39. flwr/superexec/simulation.py +7 -1
  40. {flwr_nightly-1.19.0.dev20250611.dist-info → flwr_nightly-1.19.0.dev20250613.dist-info}/METADATA +1 -1
  41. {flwr_nightly-1.19.0.dev20250611.dist-info → flwr_nightly-1.19.0.dev20250613.dist-info}/RECORD +43 -43
  42. {flwr_nightly-1.19.0.dev20250611.dist-info → flwr_nightly-1.19.0.dev20250613.dist-info}/WHEEL +0 -0
  43. {flwr_nightly-1.19.0.dev20250611.dist-info → flwr_nightly-1.19.0.dev20250613.dist-info}/entry_points.txt +0 -0
@@ -44,6 +44,8 @@ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
44
44
  SendNodeHeartbeatResponse,
45
45
  )
46
46
  from flwr.proto.message_pb2 import ( # pylint: disable=E0611
47
+ ConfirmMessageReceivedRequest,
48
+ ConfirmMessageReceivedResponse,
47
49
  ObjectIDs,
48
50
  PullObjectRequest,
49
51
  PullObjectResponse,
@@ -146,6 +148,7 @@ def push_messages(
146
148
  msg.metadata.run_id,
147
149
  [Status.PENDING, Status.STARTING, Status.FINISHED],
148
150
  state,
151
+ store,
149
152
  )
150
153
  if abort_msg:
151
154
  raise InvalidRunStatusException(abort_msg)
@@ -165,7 +168,9 @@ def push_messages(
165
168
  return response
166
169
 
167
170
 
168
- def get_run(request: GetRunRequest, state: LinkState) -> GetRunResponse:
171
+ def get_run(
172
+ request: GetRunRequest, state: LinkState, store: ObjectStore
173
+ ) -> GetRunResponse:
169
174
  """Get run information."""
170
175
  run = state.get_run(request.run_id)
171
176
 
@@ -177,6 +182,7 @@ def get_run(request: GetRunRequest, state: LinkState) -> GetRunResponse:
177
182
  request.run_id,
178
183
  [Status.PENDING, Status.STARTING, Status.FINISHED],
179
184
  state,
185
+ store,
180
186
  )
181
187
  if abort_msg:
182
188
  raise InvalidRunStatusException(abort_msg)
@@ -193,7 +199,7 @@ def get_run(request: GetRunRequest, state: LinkState) -> GetRunResponse:
193
199
 
194
200
 
195
201
  def get_fab(
196
- request: GetFabRequest, ffs: Ffs, state: LinkState # pylint: disable=W0613
202
+ request: GetFabRequest, ffs: Ffs, state: LinkState, store: ObjectStore
197
203
  ) -> GetFabResponse:
198
204
  """Get FAB."""
199
205
  # Abort if the run is not running
@@ -201,6 +207,7 @@ def get_fab(
201
207
  request.run_id,
202
208
  [Status.PENDING, Status.STARTING, Status.FINISHED],
203
209
  state,
210
+ store,
204
211
  )
205
212
  if abort_msg:
206
213
  raise InvalidRunStatusException(abort_msg)
@@ -220,6 +227,7 @@ def push_object(
220
227
  request.run_id,
221
228
  [Status.PENDING, Status.STARTING, Status.FINISHED],
222
229
  state,
230
+ store,
223
231
  )
224
232
  if abort_msg:
225
233
  raise InvalidRunStatusException(abort_msg)
@@ -245,6 +253,7 @@ def pull_object(
245
253
  request.run_id,
246
254
  [Status.PENDING, Status.STARTING, Status.FINISHED],
247
255
  state,
256
+ store,
248
257
  )
249
258
  if abort_msg:
250
259
  raise InvalidRunStatusException(abort_msg)
@@ -259,3 +268,25 @@ def pull_object(
259
268
  object_content=content,
260
269
  )
261
270
  return PullObjectResponse(object_found=False, object_available=False)
271
+
272
+
273
+ def confirm_message_received(
274
+ request: ConfirmMessageReceivedRequest,
275
+ state: LinkState,
276
+ store: ObjectStore,
277
+ ) -> ConfirmMessageReceivedResponse:
278
+ """Confirm message received handler."""
279
+ abort_msg = check_abort(
280
+ request.run_id,
281
+ [Status.PENDING, Status.STARTING, Status.FINISHED],
282
+ state,
283
+ store,
284
+ )
285
+ if abort_msg:
286
+ raise InvalidRunStatusException(abort_msg)
287
+
288
+ # Delete the message object
289
+ store.delete(request.message_object_id)
290
+ store.delete_message_descendant_ids(request.message_object_id)
291
+
292
+ return ConfirmMessageReceivedResponse()
@@ -39,6 +39,8 @@ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
39
39
  SendNodeHeartbeatResponse,
40
40
  )
41
41
  from flwr.proto.message_pb2 import ( # pylint: disable=E0611
42
+ ConfirmMessageReceivedRequest,
43
+ ConfirmMessageReceivedResponse,
42
44
  PullObjectRequest,
43
45
  PullObjectResponse,
44
46
  PushObjectRequest,
@@ -176,9 +178,10 @@ async def get_run(request: GetRunRequest) -> GetRunResponse:
176
178
  """GetRun."""
177
179
  # Get state from app
178
180
  state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
181
+ store: ObjectStore = cast(ObjectStoreFactory, app.state.OBJECTSTORE_FACTORY).store()
179
182
 
180
183
  # Handle message
181
- return message_handler.get_run(request=request, state=state)
184
+ return message_handler.get_run(request=request, state=state, store=store)
182
185
 
183
186
 
184
187
  @rest_request_response(GetFabRequest)
@@ -189,9 +192,25 @@ async def get_fab(request: GetFabRequest) -> GetFabResponse:
189
192
 
190
193
  # Get state from app
191
194
  state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
195
+ store: ObjectStore = cast(ObjectStoreFactory, app.state.OBJECTSTORE_FACTORY).store()
196
+
197
+ # Handle message
198
+ return message_handler.get_fab(request=request, ffs=ffs, state=state, store=store)
199
+
200
+
201
+ @rest_request_response(ConfirmMessageReceivedRequest)
202
+ async def confirm_message_received(
203
+ request: ConfirmMessageReceivedRequest,
204
+ ) -> ConfirmMessageReceivedResponse:
205
+ """Confirm message received."""
206
+ # Get state from app
207
+ state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
208
+ store: ObjectStore = cast(ObjectStoreFactory, app.state.OBJECTSTORE_FACTORY).store()
192
209
 
193
210
  # Handle message
194
- return message_handler.get_fab(request=request, ffs=ffs, state=state)
211
+ return message_handler.confirm_message_received(
212
+ request=request, state=state, store=store
213
+ )
195
214
 
196
215
 
197
216
  routes = [
@@ -204,6 +223,11 @@ routes = [
204
223
  Route("/api/v0/fleet/send-node-heartbeat", send_node_heartbeat, methods=["POST"]),
205
224
  Route("/api/v0/fleet/get-run", get_run, methods=["POST"]),
206
225
  Route("/api/v0/fleet/get-fab", get_fab, methods=["POST"]),
226
+ Route(
227
+ "/api/v0/fleet/confirm-message-received",
228
+ confirm_message_received,
229
+ methods=["POST"],
230
+ ),
207
231
  ]
208
232
 
209
233
  app: Starlette = Starlette(
@@ -18,6 +18,7 @@
18
18
  import threading
19
19
  import time
20
20
  from bisect import bisect_right
21
+ from collections import defaultdict
21
22
  from dataclasses import dataclass, field
22
23
  from logging import ERROR, WARNING
23
24
  from typing import Optional
@@ -79,6 +80,9 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
79
80
  self.message_res_store: dict[str, Message] = {}
80
81
  self.message_ins_id_to_message_res_id: dict[str, str] = {}
81
82
 
83
+ # Map flwr_aid to run_ids for O(1) reverse index lookup
84
+ self.flwr_aid_to_run_ids: dict[str, set[int]] = defaultdict(set)
85
+
82
86
  self.node_public_keys: set[bytes] = set()
83
87
 
84
88
  self.lock = threading.RLock()
@@ -398,6 +402,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
398
402
  fab_hash: Optional[str],
399
403
  override_config: UserConfig,
400
404
  federation_options: ConfigRecord,
405
+ flwr_aid: Optional[str],
401
406
  ) -> int:
402
407
  """Create a new run for the specified `fab_hash`."""
403
408
  # Sample a random int64 as run_id
@@ -421,9 +426,13 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
421
426
  sub_status="",
422
427
  details="",
423
428
  ),
429
+ flwr_aid=flwr_aid if flwr_aid else "",
424
430
  ),
425
431
  )
426
432
  self.run_ids[run_id] = run_record
433
+ # Add run_id to the flwr_aid_to_run_ids mapping if flwr_aid is provided
434
+ if flwr_aid:
435
+ self.flwr_aid_to_run_ids[flwr_aid].add(run_id)
427
436
 
428
437
  # Record federation options. Leave empty if not passed
429
438
  self.federation_options[run_id] = federation_options
@@ -451,9 +460,15 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
451
460
  with self.lock:
452
461
  return self.node_public_keys.copy()
453
462
 
454
- def get_run_ids(self) -> set[int]:
455
- """Retrieve all run IDs."""
463
+ def get_run_ids(self, flwr_aid: Optional[str]) -> set[int]:
464
+ """Retrieve all run IDs if `flwr_aid` is not specified.
465
+
466
+ Otherwise, retrieve all run IDs for the specified `flwr_aid`.
467
+ """
456
468
  with self.lock:
469
+ if flwr_aid is not None:
470
+ # Return run IDs for the specified flwr_aid
471
+ return set(self.flwr_aid_to_run_ids.get(flwr_aid, ()))
457
472
  return set(self.run_ids.keys())
458
473
 
459
474
  def _check_and_tag_inactive_run(self, run_ids: set[int]) -> None:
@@ -463,7 +478,9 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
463
478
  if they have not sent a heartbeat before `active_until`.
464
479
  """
465
480
  current = now()
466
- for record in [self.run_ids[run_id] for run_id in run_ids]:
481
+ for record in (self.run_ids.get(run_id) for run_id in run_ids):
482
+ if record is None:
483
+ continue
467
484
  with record.lock:
468
485
  if record.run.status.status in (Status.STARTING, Status.RUNNING):
469
486
  if record.active_until < current.timestamp():
@@ -164,12 +164,16 @@ class LinkState(abc.ABC): # pylint: disable=R0904
164
164
  fab_hash: Optional[str],
165
165
  override_config: UserConfig,
166
166
  federation_options: ConfigRecord,
167
+ flwr_aid: Optional[str],
167
168
  ) -> int:
168
169
  """Create a new run for the specified `fab_hash`."""
169
170
 
170
171
  @abc.abstractmethod
171
- def get_run_ids(self) -> set[int]:
172
- """Retrieve all run IDs."""
172
+ def get_run_ids(self, flwr_aid: Optional[str]) -> set[int]:
173
+ """Retrieve all run IDs if `flwr_aid` is not specified.
174
+
175
+ Otherwise, retrieve all run IDs for the specified `flwr_aid`.
176
+ """
173
177
 
174
178
  @abc.abstractmethod
175
179
  def get_run(self, run_id: int) -> Optional[Run]:
@@ -102,7 +102,8 @@ CREATE TABLE IF NOT EXISTS run(
102
102
  finished_at TEXT,
103
103
  sub_status TEXT,
104
104
  details TEXT,
105
- federation_options BLOB
105
+ federation_options BLOB,
106
+ flwr_aid TEXT
106
107
  );
107
108
  """
108
109
 
@@ -719,6 +720,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
719
720
  fab_hash: Optional[str],
720
721
  override_config: UserConfig,
721
722
  federation_options: ConfigRecord,
723
+ flwr_aid: Optional[str],
722
724
  ) -> int:
723
725
  """Create a new run for the specified `fab_id` and `fab_version`."""
724
726
  # Sample a random int64 as run_id
@@ -735,8 +737,8 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
735
737
  "INSERT INTO run "
736
738
  "(run_id, active_until, heartbeat_interval, fab_id, fab_version, "
737
739
  "fab_hash, override_config, federation_options, pending_at, "
738
- "starting_at, running_at, finished_at, sub_status, details) "
739
- "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);"
740
+ "starting_at, running_at, finished_at, sub_status, details, flwr_aid) "
741
+ "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);"
740
742
  )
741
743
  override_config_json = json.dumps(override_config)
742
744
  data = [
@@ -754,6 +756,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
754
756
  "",
755
757
  "",
756
758
  "",
759
+ flwr_aid or "",
757
760
  ]
758
761
  self.query(query, tuple(data))
759
762
  return uint64_run_id
@@ -782,10 +785,18 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
782
785
  result: set[bytes] = {row["public_key"] for row in rows}
783
786
  return result
784
787
 
785
- def get_run_ids(self) -> set[int]:
786
- """Retrieve all run IDs."""
787
- query = "SELECT run_id FROM run;"
788
- rows = self.query(query)
788
+ def get_run_ids(self, flwr_aid: Optional[str]) -> set[int]:
789
+ """Retrieve all run IDs if `flwr_aid` is not specified.
790
+
791
+ Otherwise, retrieve all run IDs for the specified `flwr_aid`.
792
+ """
793
+ if flwr_aid:
794
+ rows = self.query(
795
+ "SELECT run_id FROM run WHERE flwr_aid = ?;",
796
+ (flwr_aid,),
797
+ )
798
+ else:
799
+ rows = self.query("SELECT run_id FROM run;", ())
789
800
  return {convert_sint64_to_uint64(row["run_id"]) for row in rows}
790
801
 
791
802
  def _check_and_tag_inactive_run(self, run_ids: set[int]) -> None:
@@ -836,6 +847,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
836
847
  sub_status=row["sub_status"],
837
848
  details=row["details"],
838
849
  ),
850
+ flwr_aid=row["flwr_aid"],
839
851
  )
840
852
  log(ERROR, "`run_id` does not exist.")
841
853
  return None
@@ -26,6 +26,8 @@ from flwr.common.constant import SUPERLINK_NODE_ID, Status
26
26
  from flwr.common.inflatable import (
27
27
  UnexpectedObjectContentError,
28
28
  get_descendant_object_ids,
29
+ get_object_tree,
30
+ no_object_id_recompute,
29
31
  )
30
32
  from flwr.common.logger import log
31
33
  from flwr.common.serde import (
@@ -50,6 +52,8 @@ from flwr.proto.log_pb2 import ( # pylint: disable=E0611
50
52
  PushLogsResponse,
51
53
  )
52
54
  from flwr.proto.message_pb2 import ( # pylint: disable=E0611
55
+ ConfirmMessageReceivedRequest,
56
+ ConfirmMessageReceivedResponse,
53
57
  ObjectIDs,
54
58
  PullObjectRequest,
55
59
  PullObjectResponse,
@@ -107,14 +111,16 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
107
111
  """Get available nodes."""
108
112
  log(DEBUG, "ServerAppIoServicer.GetNodes")
109
113
 
110
- # Init state
111
- state: LinkState = self.state_factory.state()
114
+ # Init state and store
115
+ state = self.state_factory.state()
116
+ store = self.objectstore_factory.store()
112
117
 
113
118
  # Abort if the run is not running
114
119
  abort_if(
115
120
  request.run_id,
116
121
  [Status.PENDING, Status.STARTING, Status.FINISHED],
117
122
  state,
123
+ store,
118
124
  context,
119
125
  )
120
126
 
@@ -128,14 +134,16 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
128
134
  """Push a set of Messages."""
129
135
  log(DEBUG, "ServerAppIoServicer.PushMessages")
130
136
 
131
- # Init state
132
- state: LinkState = self.state_factory.state()
137
+ # Init state and store
138
+ state = self.state_factory.state()
139
+ store = self.objectstore_factory.store()
133
140
 
134
141
  # Abort if the run is not running
135
142
  abort_if(
136
143
  request.run_id,
137
144
  [Status.PENDING, Status.STARTING, Status.FINISHED],
138
145
  state,
146
+ store,
139
147
  context,
140
148
  )
141
149
 
@@ -146,8 +154,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
146
154
  detail="`messages_list` must not be empty",
147
155
  )
148
156
  message_ids: list[Optional[str]] = []
149
- while request.messages_list:
150
- message_proto = request.messages_list.pop(0)
157
+ for message_proto in request.messages_list:
151
158
  message = message_from_proto(message_proto=message_proto)
152
159
  validation_errors = validate_message(message, is_reply_message=False)
153
160
  _raise_if(
@@ -164,9 +171,6 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
164
171
  message_id: Optional[str] = state.store_message_ins(message=message)
165
172
  message_ids.append(message_id)
166
173
 
167
- # Init store
168
- store = self.objectstore_factory.store()
169
-
170
174
  # Store Message object to descendants mapping and preregister objects
171
175
  objects_to_push = store_mapping_and_register_objects(store, request=request)
172
176
 
@@ -183,10 +187,8 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
183
187
  """Pull a set of Messages."""
184
188
  log(DEBUG, "ServerAppIoServicer.PullMessages")
185
189
 
186
- # Init state
187
- state: LinkState = self.state_factory.state()
188
-
189
- # Init store
190
+ # Init state and store
191
+ state = self.state_factory.state()
190
192
  store = self.objectstore_factory.store()
191
193
 
192
194
  # Abort if the run is not running
@@ -194,6 +196,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
194
196
  request.run_id,
195
197
  [Status.PENDING, Status.STARTING, Status.FINISHED],
196
198
  state,
199
+ store,
197
200
  context,
198
201
  )
199
202
 
@@ -205,14 +208,15 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
205
208
  # Register messages generated by LinkState in the Store for consistency
206
209
  for msg_res in messages_res:
207
210
  if msg_res.metadata.src_node_id == SUPERLINK_NODE_ID:
208
- descendants = list(get_descendant_object_ids(msg_res))
209
- message_obj_id = msg_res.metadata.message_id
211
+ with no_object_id_recompute():
212
+ descendants = list(get_descendant_object_ids(msg_res))
213
+ message_obj_id = msg_res.metadata.message_id
210
214
  # Store mapping
211
215
  store.set_message_descendant_ids(
212
216
  msg_object_id=message_obj_id, descendant_ids=descendants
213
217
  )
214
218
  # Preregister
215
- store.preregister(descendants + [message_obj_id])
219
+ store.preregister(request.run_id, get_object_tree(msg_res))
216
220
 
217
221
  # Delete the instruction Messages and their replies if found
218
222
  message_ins_ids_to_delete = {
@@ -328,14 +332,16 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
328
332
  """Push ServerApp process outputs."""
329
333
  log(DEBUG, "ServerAppIoServicer.PushServerAppOutputs")
330
334
 
331
- # Init state
335
+ # Init state and store
332
336
  state = self.state_factory.state()
337
+ store = self.objectstore_factory.store()
333
338
 
334
339
  # Abort if the run is not running
335
340
  abort_if(
336
341
  request.run_id,
337
342
  [Status.PENDING, Status.STARTING, Status.FINISHED],
338
343
  state,
344
+ store,
339
345
  context,
340
346
  )
341
347
 
@@ -348,16 +354,23 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
348
354
  """Update the status of a run."""
349
355
  log(DEBUG, "ServerAppIoServicer.UpdateRunStatus")
350
356
 
351
- # Init state
357
+ # Init state and store
352
358
  state = self.state_factory.state()
359
+ store = self.objectstore_factory.store()
353
360
 
354
361
  # Abort if the run is finished
355
- abort_if(request.run_id, [Status.FINISHED], state, context)
362
+ abort_if(request.run_id, [Status.FINISHED], state, store, context)
356
363
 
357
364
  # Update the run status
358
365
  state.update_run_status(
359
366
  run_id=request.run_id, new_status=run_status_from_proto(request.run_status)
360
367
  )
368
+
369
+ # If the run is finished, delete the run from ObjectStore
370
+ if request.run_status.status == Status.FINISHED:
371
+ # Delete all objects related to the run
372
+ store.delete_objects_in_run(request.run_id)
373
+
361
374
  return UpdateRunStatusResponse()
362
375
 
363
376
  def PushLogs(
@@ -412,14 +425,16 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
412
425
  """Push an object to the ObjectStore."""
413
426
  log(DEBUG, "ServerAppIoServicer.PushObject")
414
427
 
415
- # Init state
416
- state: LinkState = self.state_factory.state()
428
+ # Init state and store
429
+ state = self.state_factory.state()
430
+ store = self.objectstore_factory.store()
417
431
 
418
432
  # Abort if the run is not running
419
433
  abort_if(
420
434
  request.run_id,
421
435
  [Status.PENDING, Status.STARTING, Status.FINISHED],
422
436
  state,
437
+ store,
423
438
  context,
424
439
  )
425
440
 
@@ -427,9 +442,6 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
427
442
  # Cancel insertion in ObjectStore
428
443
  context.abort(grpc.StatusCode.FAILED_PRECONDITION, "Unexpected node ID.")
429
444
 
430
- # Init store
431
- store = self.objectstore_factory.store()
432
-
433
445
  # Insert in store
434
446
  stored = False
435
447
  try:
@@ -449,14 +461,16 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
449
461
  """Pull an object from the ObjectStore."""
450
462
  log(DEBUG, "ServerAppIoServicer.PullObject")
451
463
 
452
- # Init state
453
- state: LinkState = self.state_factory.state()
464
+ # Init state and store
465
+ state = self.state_factory.state()
466
+ store = self.objectstore_factory.store()
454
467
 
455
468
  # Abort if the run is not running
456
469
  abort_if(
457
470
  request.run_id,
458
471
  [Status.PENDING, Status.STARTING, Status.FINISHED],
459
472
  state,
473
+ store,
460
474
  context,
461
475
  )
462
476
 
@@ -464,9 +478,6 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
464
478
  # Cancel insertion in ObjectStore
465
479
  context.abort(grpc.StatusCode.FAILED_PRECONDITION, "Unexpected node ID.")
466
480
 
467
- # Init store
468
- store = self.objectstore_factory.store()
469
-
470
481
  # Fetch from store
471
482
  content = store.get(request.object_id)
472
483
  if content is not None:
@@ -478,6 +489,31 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
478
489
  )
479
490
  return PullObjectResponse(object_found=False, object_available=False)
480
491
 
492
+ def ConfirmMessageReceived(
493
+ self, request: ConfirmMessageReceivedRequest, context: grpc.ServicerContext
494
+ ) -> ConfirmMessageReceivedResponse:
495
+ """Confirm message received."""
496
+ log(DEBUG, "ServerAppIoServicer.ConfirmMessageReceived")
497
+
498
+ # Init state and store
499
+ state = self.state_factory.state()
500
+ store = self.objectstore_factory.store()
501
+
502
+ # Abort if the run is not running
503
+ abort_if(
504
+ request.run_id,
505
+ [Status.PENDING, Status.STARTING, Status.FINISHED],
506
+ state,
507
+ store,
508
+ context,
509
+ )
510
+
511
+ # Delete the message object
512
+ store.delete(request.message_object_id)
513
+ store.delete_message_descendant_ids(request.message_object_id)
514
+
515
+ return ConfirmMessageReceivedResponse()
516
+
481
517
 
482
518
  def _raise_if(validation_error: bool, request_name: str, detail: str) -> None:
483
519
  """Raise a `ValueError` with a detailed message if a validation error occurs."""
@@ -121,6 +121,7 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
121
121
  request.run_id,
122
122
  [Status.PENDING, Status.STARTING, Status.FINISHED],
123
123
  state,
124
+ None,
124
125
  context,
125
126
  )
126
127
 
@@ -135,7 +136,7 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
135
136
  state = self.state_factory.state()
136
137
 
137
138
  # Abort if the run is finished
138
- abort_if(request.run_id, [Status.FINISHED], state, context)
139
+ abort_if(request.run_id, [Status.FINISHED], state, None, context)
139
140
 
140
141
  # Update the run status
141
142
  state.update_run_status(
@@ -15,11 +15,12 @@
15
15
  """SuperLink utilities."""
16
16
 
17
17
 
18
- from typing import Union
18
+ from typing import Optional, Union
19
19
 
20
20
  import grpc
21
21
 
22
22
  from flwr.common.constant import Status, SubStatus
23
+ from flwr.common.inflatable import iterate_object_tree
23
24
  from flwr.common.typing import RunStatus
24
25
  from flwr.proto.fleet_pb2 import PushMessagesRequest # pylint: disable=E0611
25
26
  from flwr.proto.message_pb2 import ObjectIDs # pylint: disable=E0611
@@ -39,6 +40,7 @@ def check_abort(
39
40
  run_id: int,
40
41
  abort_status_list: list[str],
41
42
  state: LinkState,
43
+ store: Optional[ObjectStore] = None,
42
44
  ) -> Union[str, None]:
43
45
  """Check if the status of the provided `run_id` is in `abort_status_list`."""
44
46
  run_status: RunStatus = state.get_run_status({run_id})[run_id]
@@ -49,6 +51,10 @@ def check_abort(
49
51
  msg += " Stopped by user."
50
52
  return msg
51
53
 
54
+ # Clear the objects of the run from the store if the run is finished
55
+ if store and run_status.status == Status.FINISHED:
56
+ store.delete_objects_in_run(run_id)
57
+
52
58
  return None
53
59
 
54
60
 
@@ -62,10 +68,11 @@ def abort_if(
62
68
  run_id: int,
63
69
  abort_status_list: list[str],
64
70
  state: LinkState,
71
+ store: Optional[ObjectStore],
65
72
  context: grpc.ServicerContext,
66
73
  ) -> None:
67
74
  """Abort context if status of the provided `run_id` is in `abort_status_list`."""
68
- msg = check_abort(run_id, abort_status_list, state)
75
+ msg = check_abort(run_id, abort_status_list, state, store)
69
76
  abort_grpc_context(msg, context)
70
77
 
71
78
 
@@ -73,21 +80,27 @@ def store_mapping_and_register_objects(
73
80
  store: ObjectStore, request: Union[PushInsMessagesRequest, PushMessagesRequest]
74
81
  ) -> dict[str, ObjectIDs]:
75
82
  """Store Message object to descendants mapping and preregister objects."""
83
+ if not request.messages_list:
84
+ return {}
85
+
76
86
  objects_to_push: dict[str, ObjectIDs] = {}
77
- for (
78
- message_obj_id,
79
- descendant_obj_ids,
80
- ) in request.msg_to_descendant_mapping.items():
81
- descendants = list(descendant_obj_ids.object_ids)
87
+
88
+ # Get run_id from the first message in the list
89
+ # All messages of a request should be in the same run
90
+ run_id = request.messages_list[0].metadata.run_id
91
+
92
+ for object_tree in request.message_object_trees:
93
+ all_object_ids = [obj.object_id for obj in iterate_object_tree(object_tree)]
94
+ msg_object_id, descendant_ids = all_object_ids[-1], all_object_ids[:-1]
82
95
  # Store mapping
83
96
  store.set_message_descendant_ids(
84
- msg_object_id=message_obj_id, descendant_ids=descendants
97
+ msg_object_id=msg_object_id, descendant_ids=descendant_ids
85
98
  )
86
99
 
87
100
  # Preregister
88
- object_ids_just_registered = store.preregister(descendants + [message_obj_id])
101
+ object_ids_just_registered = store.preregister(run_id, object_tree)
89
102
  # Keep track of objects that need to be pushed
90
- objects_to_push[message_obj_id] = ObjectIDs(
103
+ objects_to_push[msg_object_id] = ObjectIDs(
91
104
  object_ids=object_ids_just_registered
92
105
  )
93
106