flwr 1.19.0__py3-none-any.whl → 1.20.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94)
  1. flwr/cli/build.py +15 -5
  2. flwr/cli/new/new.py +12 -4
  3. flwr/cli/new/templates/app/README.flowertune.md.tpl +2 -0
  4. flwr/cli/new/templates/app/README.md.tpl +5 -0
  5. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +14 -3
  6. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +13 -1
  7. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +21 -2
  8. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +18 -1
  9. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +19 -2
  10. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +18 -1
  11. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +20 -3
  12. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +18 -1
  13. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +18 -1
  14. flwr/cli/run/run.py +45 -38
  15. flwr/cli/utils.py +12 -5
  16. flwr/client/grpc_adapter_client/connection.py +11 -4
  17. flwr/client/grpc_rere_client/connection.py +92 -117
  18. flwr/client/rest_client/connection.py +131 -164
  19. flwr/common/constant.py +3 -1
  20. flwr/common/exit/exit_code.py +16 -1
  21. flwr/common/grpc.py +12 -1
  22. flwr/common/{inflatable_grpc_utils.py → inflatable_protobuf_utils.py} +52 -10
  23. flwr/common/inflatable_utils.py +191 -24
  24. flwr/common/record/array.py +101 -22
  25. flwr/common/record/arraychunk.py +59 -0
  26. flwr/common/serde.py +0 -28
  27. flwr/compat/client/app.py +14 -31
  28. flwr/proto/appio_pb2.py +43 -0
  29. flwr/proto/appio_pb2.pyi +151 -0
  30. flwr/proto/appio_pb2_grpc.py +4 -0
  31. flwr/proto/appio_pb2_grpc.pyi +4 -0
  32. flwr/proto/clientappio_pb2.py +12 -19
  33. flwr/proto/clientappio_pb2.pyi +23 -101
  34. flwr/proto/clientappio_pb2_grpc.py +269 -28
  35. flwr/proto/clientappio_pb2_grpc.pyi +114 -20
  36. flwr/proto/fleet_pb2.py +12 -20
  37. flwr/proto/fleet_pb2.pyi +6 -36
  38. flwr/proto/serverappio_pb2.py +8 -31
  39. flwr/proto/serverappio_pb2.pyi +0 -152
  40. flwr/proto/serverappio_pb2_grpc.py +39 -38
  41. flwr/proto/serverappio_pb2_grpc.pyi +21 -20
  42. flwr/server/app.py +1 -1
  43. flwr/server/fleet_event_log_interceptor.py +4 -0
  44. flwr/server/grid/grpc_grid.py +91 -54
  45. flwr/server/serverapp/app.py +27 -17
  46. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +8 -0
  47. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +1 -1
  48. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -5
  49. flwr/server/superlink/fleet/message_handler/message_handler.py +10 -16
  50. flwr/server/superlink/fleet/rest_rere/rest_api.py +1 -2
  51. flwr/server/superlink/serverappio/serverappio_grpc.py +1 -1
  52. flwr/server/superlink/serverappio/serverappio_servicer.py +35 -43
  53. flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
  54. flwr/server/superlink/simulation/simulationio_servicer.py +1 -1
  55. flwr/server/superlink/utils.py +0 -35
  56. flwr/simulation/app.py +8 -0
  57. flwr/simulation/run_simulation.py +17 -0
  58. flwr/{server/superlink → supercore}/ffs/disk_ffs.py +1 -1
  59. flwr/supercore/grpc_health/__init__.py +22 -0
  60. flwr/supercore/grpc_health/simple_health_servicer.py +38 -0
  61. flwr/supercore/license_plugin/__init__.py +22 -0
  62. flwr/supercore/license_plugin/license_plugin.py +26 -0
  63. flwr/supercore/object_store/in_memory_object_store.py +31 -31
  64. flwr/supercore/object_store/object_store.py +20 -42
  65. flwr/supercore/object_store/utils.py +43 -0
  66. flwr/supercore/scheduler/__init__.py +22 -0
  67. flwr/supercore/scheduler/plugin.py +71 -0
  68. flwr/supercore/utils.py +32 -0
  69. flwr/superexec/deployment.py +1 -2
  70. flwr/superexec/exec_event_log_interceptor.py +4 -0
  71. flwr/superexec/exec_grpc.py +18 -2
  72. flwr/superexec/exec_license_interceptor.py +82 -0
  73. flwr/superexec/exec_servicer.py +10 -1
  74. flwr/superexec/exec_user_auth_interceptor.py +10 -2
  75. flwr/superexec/executor.py +1 -1
  76. flwr/superexec/simulation.py +1 -2
  77. flwr/supernode/cli/flower_supernode.py +0 -7
  78. flwr/supernode/cli/flwr_clientapp.py +10 -3
  79. flwr/supernode/nodestate/in_memory_nodestate.py +11 -2
  80. flwr/supernode/nodestate/nodestate.py +15 -0
  81. flwr/supernode/runtime/run_clientapp.py +110 -33
  82. flwr/supernode/scheduler/__init__.py +22 -0
  83. flwr/supernode/scheduler/simple_clientapp_scheduler_plugin.py +49 -0
  84. flwr/supernode/servicer/clientappio/__init__.py +1 -3
  85. flwr/supernode/servicer/clientappio/clientappio_servicer.py +223 -164
  86. flwr/supernode/start_client_internal.py +202 -104
  87. {flwr-1.19.0.dist-info → flwr-1.20.0.dist-info}/METADATA +2 -1
  88. {flwr-1.19.0.dist-info → flwr-1.20.0.dist-info}/RECORD +93 -78
  89. flwr/common/inflatable_rest_utils.py +0 -99
  90. /flwr/{server/superlink → supercore}/ffs/__init__.py +0 -0
  91. /flwr/{server/superlink → supercore}/ffs/ffs.py +0 -0
  92. /flwr/{server/superlink → supercore}/ffs/ffs_factory.py +0 -0
  93. {flwr-1.19.0.dist-info → flwr-1.20.0.dist-info}/WHEEL +0 -0
  94. {flwr-1.19.0.dist-info → flwr-1.20.0.dist-info}/entry_points.txt +0 -0
@@ -22,31 +22,43 @@ from typing import Optional, cast
22
22
 
23
23
  import grpc
24
24
 
25
- from flwr.common import Message, RecordDict
25
+ from flwr.app.error import Error
26
+ from flwr.common import Message, Metadata, RecordDict, now
26
27
  from flwr.common.constant import (
27
28
  SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS,
28
29
  SUPERLINK_NODE_ID,
30
+ ErrorCode,
31
+ MessageType,
29
32
  )
30
33
  from flwr.common.grpc import create_channel, on_channel_state_change
31
34
  from flwr.common.inflatable import (
35
+ InflatableObject,
32
36
  get_all_nested_objects,
33
37
  get_object_tree,
38
+ iterate_object_tree,
34
39
  no_object_id_recompute,
35
40
  )
36
- from flwr.common.inflatable_grpc_utils import (
37
- make_pull_object_fn_grpc,
38
- make_push_object_fn_grpc,
41
+ from flwr.common.inflatable_protobuf_utils import (
42
+ make_pull_object_fn_protobuf,
43
+ make_push_object_fn_protobuf,
39
44
  )
40
45
  from flwr.common.inflatable_utils import (
46
+ ObjectUnavailableError,
41
47
  inflate_object_from_contents,
42
48
  pull_objects,
43
49
  push_objects,
44
50
  )
45
51
  from flwr.common.logger import log, warn_deprecated_feature
46
- from flwr.common.message import remove_content_from_message
52
+ from flwr.common.message import make_message, remove_content_from_message
47
53
  from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub
48
54
  from flwr.common.serde import message_to_proto, run_from_proto
49
55
  from flwr.common.typing import Run
56
+ from flwr.proto.appio_pb2 import ( # pylint: disable=E0611
57
+ PullAppMessagesRequest,
58
+ PullAppMessagesResponse,
59
+ PushAppMessagesRequest,
60
+ PushAppMessagesResponse,
61
+ )
50
62
  from flwr.proto.message_pb2 import ( # pylint: disable=E0611
51
63
  ConfirmMessageReceivedRequest,
52
64
  )
@@ -55,10 +67,6 @@ from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=
55
67
  from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
56
68
  GetNodesRequest,
57
69
  GetNodesResponse,
58
- PullResMessagesRequest,
59
- PullResMessagesResponse,
60
- PushInsMessagesRequest,
61
- PushInsMessagesResponse,
62
70
  )
63
71
  from flwr.proto.serverappio_pb2_grpc import ServerAppIoStub # pylint: disable=E0611
64
72
 
@@ -215,37 +223,38 @@ class GrpcGrid(Grid):
215
223
  )
216
224
  return [node.node_id for node in res.nodes]
217
225
 
218
- def _try_push_message(self, run_id: int, message: Message) -> str:
219
- """Push one message and its associated objects."""
220
- # Compute mapping of message descendants
221
- all_objects = get_all_nested_objects(message)
222
- msg_id = message.object_id
223
- object_tree = get_object_tree(message)
226
+ def _try_push_messages(self, run_id: int, messages: Iterable[Message]) -> list[str]:
227
+ """Push all messages and its associated objects."""
228
+ # Prepare all Messages to be sent in a single request
229
+ proto_messages = []
230
+ object_trees = []
231
+ all_objects: dict[str, InflatableObject] = {}
232
+ for msg in messages:
233
+ proto_messages.append(message_to_proto(remove_content_from_message(msg)))
234
+ all_objects.update(get_all_nested_objects(msg))
235
+ object_trees.append(get_object_tree(msg))
236
+ del msg
224
237
 
225
238
  # Call GrpcServerAppIoStub method
226
- res: PushInsMessagesResponse = self._stub.PushMessages(
227
- PushInsMessagesRequest(
228
- messages_list=[message_to_proto(remove_content_from_message(message))],
239
+ res: PushAppMessagesResponse = self._stub.PushMessages(
240
+ PushAppMessagesRequest(
241
+ messages_list=proto_messages,
229
242
  run_id=run_id,
230
- message_object_trees=[object_tree],
243
+ message_object_trees=object_trees,
231
244
  )
232
245
  )
233
246
 
234
247
  # Push objects
235
- # If Message was added to the LinkState correctly
236
- if msg_id is not None:
237
- obj_ids_to_push = set(res.objects_to_push[msg_id].object_ids)
238
- # Push only object that are not in the store
239
- push_objects(
240
- all_objects,
241
- push_object_fn=make_push_object_fn_grpc(
242
- push_object_grpc=self._stub.PushObject,
243
- node=self.node,
244
- run_id=run_id,
245
- ),
246
- object_ids_to_push=obj_ids_to_push,
247
- )
248
- return msg_id
248
+ push_objects(
249
+ all_objects,
250
+ push_object_fn=make_push_object_fn_protobuf(
251
+ push_object_protobuf=self._stub.PushObject,
252
+ node=self.node,
253
+ run_id=run_id,
254
+ ),
255
+ object_ids_to_push=set(res.objects_to_push),
256
+ )
257
+ return cast(list[str], res.message_ids)
249
258
 
250
259
  def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
251
260
  """Push messages to specified node IDs.
@@ -257,16 +266,16 @@ class GrpcGrid(Grid):
257
266
  run_id = cast(Run, self._run).run_id
258
267
  message_ids: list[str] = []
259
268
  try:
260
- for msg in messages:
261
- # Populate metadata
262
- msg.metadata.__dict__["_run_id"] = run_id
263
- msg.metadata.__dict__["_src_node_id"] = self.node.node_id
264
- msg.metadata.__dict__["_message_id"] = msg.object_id
265
- # Check message
266
- self._check_message(msg)
267
- # Try pushing message and its objects
268
- with no_object_id_recompute():
269
- message_ids.append(self._try_push_message(run_id, msg))
269
+ with no_object_id_recompute():
270
+ for msg in messages:
271
+ # Populate metadata
272
+ msg.metadata.__dict__["_run_id"] = run_id
273
+ msg.metadata.__dict__["_src_node_id"] = self.node.node_id
274
+ msg.metadata.__dict__["_message_id"] = msg.object_id
275
+ # Check message
276
+ self._check_message(msg)
277
+ # Try pushing messages and their objects
278
+ message_ids = self._try_push_messages(run_id, messages)
270
279
 
271
280
  except grpc.RpcError as e:
272
281
  if e.code() == grpc.StatusCode.RESOURCE_EXHAUSTED: # pylint: disable=E1101
@@ -294,24 +303,52 @@ class GrpcGrid(Grid):
294
303
  run_id = cast(Run, self._run).run_id
295
304
  try:
296
305
  # Pull Messages
297
- res: PullResMessagesResponse = self._stub.PullMessages(
298
- PullResMessagesRequest(
306
+ res: PullAppMessagesResponse = self._stub.PullMessages(
307
+ PullAppMessagesRequest(
299
308
  message_ids=message_ids,
300
309
  run_id=run_id,
301
310
  )
302
311
  )
303
312
  # Pull Messages from store
304
313
  inflated_msgs: list[Message] = []
305
- for msg_proto in res.messages_list:
314
+ for msg_proto, msg_tree in zip(res.messages_list, res.message_object_trees):
306
315
  msg_id = msg_proto.metadata.message_id
307
- all_object_contents = pull_objects(
308
- list(res.objects_to_pull[msg_id].object_ids) + [msg_id],
309
- pull_object_fn=make_pull_object_fn_grpc(
310
- pull_object_grpc=self._stub.PullObject,
311
- node=self.node,
312
- run_id=run_id,
313
- ),
314
- )
316
+ try:
317
+ all_object_contents = pull_objects(
318
+ object_ids=[
319
+ tree.object_id for tree in iterate_object_tree(msg_tree)
320
+ ],
321
+ pull_object_fn=make_pull_object_fn_protobuf(
322
+ pull_object_protobuf=self._stub.PullObject,
323
+ node=self.node,
324
+ run_id=run_id,
325
+ ),
326
+ )
327
+ except ObjectUnavailableError as e:
328
+ # An ObjectUnavailableError indicates that the object is not yet
329
+ # available. If this point has been reached, it means that the
330
+ # Grid has tried to pull the object for the maximum number of times
331
+ # or for the maximum time allowed, so we return an inflated message
332
+ # with an error
333
+ inflated_msgs.append(
334
+ make_message(
335
+ metadata=Metadata(
336
+ run_id=run_id,
337
+ message_id="",
338
+ src_node_id=self.node.node_id,
339
+ dst_node_id=self.node.node_id,
340
+ message_type=MessageType.SYSTEM,
341
+ group_id="",
342
+ ttl=0,
343
+ reply_to_message_id=msg_proto.metadata.reply_to_message_id,
344
+ created_at=now().timestamp(),
345
+ ),
346
+ error=Error(
347
+ code=ErrorCode.MESSAGE_UNAVAILABLE, reason=(str(e))
348
+ ),
349
+ )
350
+ )
351
+ continue
315
352
 
316
353
  # Confirm that the message has been received
317
354
  self._stub.ConfirmMessageReceived(
@@ -16,6 +16,7 @@
16
16
 
17
17
 
18
18
  import argparse
19
+ import gc
19
20
  from logging import DEBUG, ERROR, INFO
20
21
  from pathlib import Path
21
22
  from queue import Queue
@@ -55,12 +56,12 @@ from flwr.common.serde import (
55
56
  )
56
57
  from flwr.common.telemetry import EventType, event
57
58
  from flwr.common.typing import RunNotRunningException, RunStatus
58
- from flwr.proto.run_pb2 import UpdateRunStatusRequest # pylint: disable=E0611
59
- from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
60
- PullServerAppInputsRequest,
61
- PullServerAppInputsResponse,
62
- PushServerAppOutputsRequest,
59
+ from flwr.proto.appio_pb2 import ( # pylint: disable=E0611
60
+ PullAppInputsRequest,
61
+ PullAppInputsResponse,
62
+ PushAppOutputsRequest,
63
63
  )
64
+ from flwr.proto.run_pb2 import UpdateRunStatusRequest # pylint: disable=E0611
64
65
  from flwr.server.grid.grpc_grid import GrpcGrid
65
66
  from flwr.server.run_serverapp import run as run_
66
67
 
@@ -107,11 +108,6 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
107
108
  certificates: Optional[bytes] = None,
108
109
  ) -> None:
109
110
  """Run Flower ServerApp process."""
110
- grid = GrpcGrid(
111
- serverappio_service_address=serverappio_api_address,
112
- root_certificates=certificates,
113
- )
114
-
115
111
  # Resolve directory where FABs are installed
116
112
  flwr_dir_ = get_flwr_dir(flwr_dir)
117
113
  log_uploader = None
@@ -119,13 +115,21 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
119
115
  hash_run_id = None
120
116
  run_status = None
121
117
  heartbeat_sender = None
118
+ grid = None
119
+ context = None
122
120
  while True:
123
121
 
124
122
  try:
123
+ # Initialize the GrpcGrid
124
+ grid = GrpcGrid(
125
+ serverappio_service_address=serverappio_api_address,
126
+ root_certificates=certificates,
127
+ )
128
+
125
129
  # Pull ServerAppInputs from LinkState
126
- req = PullServerAppInputsRequest()
130
+ req = PullAppInputsRequest()
127
131
  log(DEBUG, "[flwr-serverapp] Pull ServerAppInputs")
128
- res: PullServerAppInputsResponse = grid._stub.PullServerAppInputs(req)
132
+ res: PullAppInputsResponse = grid._stub.PullAppInputs(req)
129
133
  if not res.HasField("run"):
130
134
  sleep(3)
131
135
  run_status = None
@@ -205,10 +209,8 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
205
209
  # Send resulting context
206
210
  context_proto = context_to_proto(updated_context)
207
211
  log(DEBUG, "[flwr-serverapp] Will push ServerAppOutputs")
208
- out_req = PushServerAppOutputsRequest(
209
- run_id=run.run_id, context=context_proto
210
- )
211
- _ = grid._stub.PushServerAppOutputs(out_req)
212
+ out_req = PushAppOutputsRequest(run_id=run.run_id, context=context_proto)
213
+ _ = grid._stub.PushAppOutputs(out_req)
212
214
 
213
215
  run_status = RunStatus(Status.FINISHED, SubStatus.COMPLETED, "")
214
216
  except RunNotRunningException:
@@ -236,7 +238,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
236
238
  log_uploader = None
237
239
 
238
240
  # Update run status
239
- if run_status:
241
+ if run_status and grid:
240
242
  run_status_proto = run_status_to_proto(run_status)
241
243
  grid._stub.UpdateRunStatus(
242
244
  UpdateRunStatusRequest(
@@ -244,6 +246,14 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
244
246
  )
245
247
  )
246
248
 
249
+ # Close the Grpc connection
250
+ if grid:
251
+ grid.close()
252
+
253
+ # Clean up the Context
254
+ context = None
255
+ gc.collect()
256
+
247
257
  event(
248
258
  EventType.FLWR_SERVERAPP_RUN_LEAVE,
249
259
  event_details={"run-id-hash": hash_run_id, "success": success},
@@ -41,6 +41,7 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
41
41
  from flwr.proto.grpcadapter_pb2 import MessageContainer # pylint: disable=E0611
42
42
  from flwr.proto.heartbeat_pb2 import SendNodeHeartbeatRequest # pylint: disable=E0611
43
43
  from flwr.proto.message_pb2 import ( # pylint: disable=E0611
44
+ ConfirmMessageReceivedRequest,
44
45
  PullObjectRequest,
45
46
  PushObjectRequest,
46
47
  )
@@ -101,4 +102,11 @@ class GrpcAdapterServicer(grpcadapter_pb2_grpc.GrpcAdapterServicer, FleetService
101
102
  return _handle(request, context, PushObjectRequest, self.PushObject)
102
103
  if request.grpc_message_name == PullObjectRequest.__qualname__:
103
104
  return _handle(request, context, PullObjectRequest, self.PullObject)
105
+ if request.grpc_message_name == ConfirmMessageReceivedRequest.__qualname__:
106
+ return _handle(
107
+ request,
108
+ context,
109
+ ConfirmMessageReceivedRequest,
110
+ self.ConfirmMessageReceived,
111
+ )
104
112
  raise ValueError(f"Invalid grpc_message_name: {request.grpc_message_name}")
@@ -48,10 +48,10 @@ from flwr.proto.message_pb2 import ( # pylint: disable=E0611
48
48
  PushObjectResponse,
49
49
  )
50
50
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
51
- from flwr.server.superlink.ffs.ffs_factory import FfsFactory
52
51
  from flwr.server.superlink.fleet.message_handler import message_handler
53
52
  from flwr.server.superlink.linkstate import LinkStateFactory
54
53
  from flwr.server.superlink.utils import abort_grpc_context
54
+ from flwr.supercore.ffs import FfsFactory
55
55
  from flwr.supercore.object_store import ObjectStoreFactory
56
56
 
57
57
 
@@ -81,12 +81,9 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
81
81
  metadata sent by the node. Continue RPC call if node is authenticated, else,
82
82
  terminate RPC call by setting context to abort.
83
83
  """
84
- # Filter out non-Fleet service calls
84
+ # Only apply to Fleet service
85
85
  if not handler_call_details.method.startswith("/flwr.proto.Fleet/"):
86
- return _unary_unary_rpc_terminator(
87
- "This request should be sent to a different service.",
88
- grpc.StatusCode.FAILED_PRECONDITION,
89
- )
86
+ return continuation(handler_call_details)
90
87
 
91
88
  state = self.state_factory.state()
92
89
  metadata_dict = dict(handler_call_details.invocation_metadata)
@@ -46,7 +46,6 @@ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
46
46
  from flwr.proto.message_pb2 import ( # pylint: disable=E0611
47
47
  ConfirmMessageReceivedRequest,
48
48
  ConfirmMessageReceivedResponse,
49
- ObjectIDs,
50
49
  PullObjectRequest,
51
50
  PullObjectResponse,
52
51
  PushObjectRequest,
@@ -58,12 +57,11 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
58
57
  GetRunResponse,
59
58
  Run,
60
59
  )
61
- from flwr.server.superlink.ffs.ffs import Ffs
62
60
  from flwr.server.superlink.linkstate import LinkState
63
61
  from flwr.server.superlink.utils import check_abort
62
+ from flwr.supercore.ffs import Ffs
64
63
  from flwr.supercore.object_store import NoObjectInStoreError, ObjectStore
65
-
66
- from ...utils import store_mapping_and_register_objects
64
+ from flwr.supercore.object_store.utils import store_mapping_and_register_objects
67
65
 
68
66
 
69
67
  def create_node(
@@ -113,25 +111,22 @@ def pull_messages(
113
111
 
114
112
  # Convert to Messages
115
113
  msg_proto = []
116
- objects_to_pull: dict[str, ObjectIDs] = {}
114
+ trees = []
117
115
  for msg in message_list:
118
116
  try:
119
- msg_proto.append(message_to_proto(msg))
120
-
117
+ # Retrieve Message object tree from ObjectStore
121
118
  msg_object_id = msg.metadata.message_id
122
- descendants = store.get_message_descendant_ids(msg_object_id)
123
- # Include the object_id of the message itself
124
- objects_to_pull[msg_object_id] = ObjectIDs(
125
- object_ids=descendants + [msg_object_id]
126
- )
119
+ obj_tree = store.get_object_tree(msg_object_id)
120
+
121
+ # Add Message and its object tree to the response
122
+ msg_proto.append(message_to_proto(msg))
123
+ trees.append(obj_tree)
127
124
  except NoObjectInStoreError as e:
128
125
  log(ERROR, e.message)
129
126
  # Delete message ins from state
130
127
  state.delete_messages(message_ins_ids={msg_object_id})
131
128
 
132
- return PullMessagesResponse(
133
- messages_list=msg_proto, objects_to_pull=objects_to_pull
134
- )
129
+ return PullMessagesResponse(messages_list=msg_proto, message_object_trees=trees)
135
130
 
136
131
 
137
132
  def push_messages(
@@ -287,6 +282,5 @@ def confirm_message_received(
287
282
 
288
283
  # Delete the message object
289
284
  store.delete(request.message_object_id)
290
- store.delete_message_descendant_ids(request.message_object_id)
291
285
 
292
286
  return ConfirmMessageReceivedResponse()
@@ -47,10 +47,9 @@ from flwr.proto.message_pb2 import ( # pylint: disable=E0611
47
47
  PushObjectResponse,
48
48
  )
49
49
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
50
- from flwr.server.superlink.ffs.ffs import Ffs
51
- from flwr.server.superlink.ffs.ffs_factory import FfsFactory
52
50
  from flwr.server.superlink.fleet.message_handler import message_handler
53
51
  from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
52
+ from flwr.supercore.ffs import Ffs, FfsFactory
54
53
  from flwr.supercore.object_store import ObjectStore, ObjectStoreFactory
55
54
 
56
55
  try:
@@ -26,8 +26,8 @@ from flwr.common.logger import log
26
26
  from flwr.proto.serverappio_pb2_grpc import ( # pylint: disable=E0611
27
27
  add_ServerAppIoServicer_to_server,
28
28
  )
29
- from flwr.server.superlink.ffs.ffs_factory import FfsFactory
30
29
  from flwr.server.superlink.linkstate import LinkStateFactory
30
+ from flwr.supercore.ffs import FfsFactory
31
31
  from flwr.supercore.object_store import ObjectStoreFactory
32
32
 
33
33
  from .serverappio_servicer import ServerAppIoServicer
@@ -42,6 +42,16 @@ from flwr.common.serde import (
42
42
  )
43
43
  from flwr.common.typing import Fab, RunStatus
44
44
  from flwr.proto import serverappio_pb2_grpc # pylint: disable=E0611
45
+ from flwr.proto.appio_pb2 import ( # pylint: disable=E0611
46
+ PullAppInputsRequest,
47
+ PullAppInputsResponse,
48
+ PullAppMessagesRequest,
49
+ PullAppMessagesResponse,
50
+ PushAppMessagesRequest,
51
+ PushAppMessagesResponse,
52
+ PushAppOutputsRequest,
53
+ PushAppOutputsResponse,
54
+ )
45
55
  from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
46
56
  from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
47
57
  SendAppHeartbeatRequest,
@@ -54,7 +64,6 @@ from flwr.proto.log_pb2 import ( # pylint: disable=E0611
54
64
  from flwr.proto.message_pb2 import ( # pylint: disable=E0611
55
65
  ConfirmMessageReceivedRequest,
56
66
  ConfirmMessageReceivedResponse,
57
- ObjectIDs,
58
67
  PullObjectRequest,
59
68
  PullObjectResponse,
60
69
  PushObjectRequest,
@@ -72,23 +81,13 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
72
81
  from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
73
82
  GetNodesRequest,
74
83
  GetNodesResponse,
75
- PullResMessagesRequest,
76
- PullResMessagesResponse,
77
- PullServerAppInputsRequest,
78
- PullServerAppInputsResponse,
79
- PushInsMessagesRequest,
80
- PushInsMessagesResponse,
81
- PushServerAppOutputsRequest,
82
- PushServerAppOutputsResponse,
83
84
  )
84
- from flwr.server.superlink.ffs.ffs import Ffs
85
- from flwr.server.superlink.ffs.ffs_factory import FfsFactory
86
85
  from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
87
86
  from flwr.server.superlink.utils import abort_if
88
87
  from flwr.server.utils.validator import validate_message
88
+ from flwr.supercore.ffs import Ffs, FfsFactory
89
89
  from flwr.supercore.object_store import NoObjectInStoreError, ObjectStoreFactory
90
-
91
- from ..utils import store_mapping_and_register_objects
90
+ from flwr.supercore.object_store.utils import store_mapping_and_register_objects
92
91
 
93
92
 
94
93
  class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
@@ -129,8 +128,8 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
129
128
  return GetNodesResponse(nodes=nodes)
130
129
 
131
130
  def PushMessages(
132
- self, request: PushInsMessagesRequest, context: grpc.ServicerContext
133
- ) -> PushInsMessagesResponse:
131
+ self, request: PushAppMessagesRequest, context: grpc.ServicerContext
132
+ ) -> PushAppMessagesResponse:
134
133
  """Push a set of Messages."""
135
134
  log(DEBUG, "ServerAppIoServicer.PushMessages")
136
135
 
@@ -174,7 +173,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
174
173
  # Store Message object to descendants mapping and preregister objects
175
174
  objects_to_push = store_mapping_and_register_objects(store, request=request)
176
175
 
177
- return PushInsMessagesResponse(
176
+ return PushAppMessagesResponse(
178
177
  message_ids=[
179
178
  str(message_id) if message_id else "" for message_id in message_ids
180
179
  ],
@@ -182,8 +181,8 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
182
181
  )
183
182
 
184
183
  def PullMessages( # pylint: disable=R0914
185
- self, request: PullResMessagesRequest, context: grpc.ServicerContext
186
- ) -> PullResMessagesResponse:
184
+ self, request: PullAppMessagesRequest, context: grpc.ServicerContext
185
+ ) -> PullAppMessagesResponse:
187
186
  """Pull a set of Messages."""
188
187
  log(DEBUG, "ServerAppIoServicer.PullMessages")
189
188
 
@@ -210,12 +209,6 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
210
209
  if msg_res.metadata.src_node_id == SUPERLINK_NODE_ID:
211
210
  with no_object_id_recompute():
212
211
  all_objects = get_all_nested_objects(msg_res)
213
- descendants = list(all_objects.keys())[:-1]
214
- message_obj_id = msg_res.metadata.message_id
215
- # Store mapping
216
- store.set_message_descendant_ids(
217
- msg_object_id=message_obj_id, descendant_ids=descendants
218
- )
219
212
  # Preregister
220
213
  store.preregister(request.run_id, get_object_tree(msg_res))
221
214
  # Store objects
@@ -231,7 +224,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
231
224
 
232
225
  # Convert Messages to proto
233
226
  messages_list = []
234
- objects_to_pull: dict[str, ObjectIDs] = {}
227
+ trees = []
235
228
  while messages_res:
236
229
  msg = messages_res.pop(0)
237
230
 
@@ -242,20 +235,20 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
242
235
  request_name="PullMessages",
243
236
  detail="`message.metadata` has mismatched `run_id`",
244
237
  )
245
- messages_list.append(message_to_proto(msg))
246
238
 
247
239
  try:
248
240
  msg_object_id = msg.metadata.message_id
249
- descendants = store.get_message_descendant_ids(msg_object_id)
250
- # Add mapping of message object ID to its descendants
251
- objects_to_pull[msg_object_id] = ObjectIDs(object_ids=descendants)
241
+ obj_tree = store.get_object_tree(msg_object_id)
242
+ # Add message and object tree to the response
243
+ messages_list.append(message_to_proto(msg))
244
+ trees.append(obj_tree)
252
245
  except NoObjectInStoreError as e:
253
246
  log(ERROR, e.message)
254
247
  # Delete message ins from state
255
248
  state.delete_messages(message_ins_ids={msg_object_id})
256
249
 
257
- return PullResMessagesResponse(
258
- messages_list=messages_list, objects_to_pull=objects_to_pull
250
+ return PullAppMessagesResponse(
251
+ messages_list=messages_list, message_object_trees=trees
259
252
  )
260
253
 
261
254
  def GetRun(
@@ -288,11 +281,11 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
288
281
 
289
282
  raise ValueError(f"Found no FAB with hash: {request.hash_str}")
290
283
 
291
- def PullServerAppInputs(
292
- self, request: PullServerAppInputsRequest, context: grpc.ServicerContext
293
- ) -> PullServerAppInputsResponse:
284
+ def PullAppInputs(
285
+ self, request: PullAppInputsRequest, context: grpc.ServicerContext
286
+ ) -> PullAppInputsResponse:
294
287
  """Pull ServerApp process inputs."""
295
- log(DEBUG, "ServerAppIoServicer.PullServerAppInputs")
288
+ log(DEBUG, "ServerAppIoServicer.PullAppInputs")
296
289
  # Init access to LinkState
297
290
  state = self.state_factory.state()
298
291
 
@@ -302,7 +295,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
302
295
  run_id = state.get_pending_run_id()
303
296
  # If there's no pending run, return an empty response
304
297
  if run_id is None:
305
- return PullServerAppInputsResponse()
298
+ return PullAppInputsResponse()
306
299
 
307
300
  # Init access to Ffs
308
301
  ffs = self.ffs_factory.ffs()
@@ -318,7 +311,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
318
311
  # Update run status to STARTING
319
312
  if state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")):
320
313
  log(INFO, "Starting run %d", run_id)
321
- return PullServerAppInputsResponse(
314
+ return PullAppInputsResponse(
322
315
  context=context_to_proto(serverapp_ctxt),
323
316
  run=run_to_proto(run),
324
317
  fab=fab_to_proto(fab),
@@ -328,11 +321,11 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
328
321
  # or if the status cannot be updated to STARTING
329
322
  raise RuntimeError(f"Failed to start run {run_id}")
330
323
 
331
- def PushServerAppOutputs(
332
- self, request: PushServerAppOutputsRequest, context: grpc.ServicerContext
333
- ) -> PushServerAppOutputsResponse:
324
+ def PushAppOutputs(
325
+ self, request: PushAppOutputsRequest, context: grpc.ServicerContext
326
+ ) -> PushAppOutputsResponse:
334
327
  """Push ServerApp process outputs."""
335
- log(DEBUG, "ServerAppIoServicer.PushServerAppOutputs")
328
+ log(DEBUG, "ServerAppIoServicer.PushAppOutputs")
336
329
 
337
330
  # Init state and store
338
331
  state = self.state_factory.state()
@@ -348,7 +341,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
348
341
  )
349
342
 
350
343
  state.set_serverapp_context(request.run_id, context_from_proto(request.context))
351
- return PushServerAppOutputsResponse()
344
+ return PushAppOutputsResponse()
352
345
 
353
346
  def UpdateRunStatus(
354
347
  self, request: UpdateRunStatusRequest, context: grpc.ServicerContext
@@ -512,7 +505,6 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
512
505
 
513
506
  # Delete the message object
514
507
  store.delete(request.message_object_id)
515
- store.delete_message_descendant_ids(request.message_object_id)
516
508
 
517
509
  return ConfirmMessageReceivedResponse()
518
510
 
@@ -26,8 +26,8 @@ from flwr.common.logger import log
26
26
  from flwr.proto.simulationio_pb2_grpc import ( # pylint: disable=E0611
27
27
  add_SimulationIoServicer_to_server,
28
28
  )
29
- from flwr.server.superlink.ffs.ffs_factory import FfsFactory
30
29
  from flwr.server.superlink.linkstate import LinkStateFactory
30
+ from flwr.supercore.ffs import FfsFactory
31
31
 
32
32
  from .simulationio_servicer import SimulationIoServicer
33
33
 
@@ -56,9 +56,9 @@ from flwr.proto.simulationio_pb2 import ( # pylint: disable=E0611
56
56
  PushSimulationOutputsRequest,
57
57
  PushSimulationOutputsResponse,
58
58
  )
59
- from flwr.server.superlink.ffs.ffs_factory import FfsFactory
60
59
  from flwr.server.superlink.linkstate import LinkStateFactory
61
60
  from flwr.server.superlink.utils import abort_if
61
+ from flwr.supercore.ffs import FfsFactory
62
62
 
63
63
 
64
64
  class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):