flwr-nightly 1.20.0.dev20250712__py3-none-any.whl → 1.20.0.dev20250715__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34) hide show
  1. flwr/client/grpc_rere_client/connection.py +3 -1
  2. flwr/client/rest_client/connection.py +3 -1
  3. flwr/common/grpc.py +12 -1
  4. flwr/common/inflatable_utils.py +14 -7
  5. flwr/proto/appio_pb2.py +51 -0
  6. flwr/proto/appio_pb2.pyi +167 -0
  7. flwr/proto/appio_pb2_grpc.py +4 -0
  8. flwr/proto/appio_pb2_grpc.pyi +4 -0
  9. flwr/proto/clientappio_pb2.py +19 -11
  10. flwr/proto/clientappio_pb2.pyi +50 -12
  11. flwr/proto/clientappio_pb2_grpc.py +68 -0
  12. flwr/proto/clientappio_pb2_grpc.pyi +26 -0
  13. flwr/proto/fleet_pb2.py +14 -18
  14. flwr/proto/fleet_pb2.pyi +4 -19
  15. flwr/proto/serverappio_pb2.py +8 -31
  16. flwr/proto/serverappio_pb2.pyi +0 -152
  17. flwr/proto/serverappio_pb2_grpc.py +39 -38
  18. flwr/proto/serverappio_pb2_grpc.pyi +21 -20
  19. flwr/server/grid/grpc_grid.py +10 -8
  20. flwr/server/serverapp/app.py +9 -11
  21. flwr/server/superlink/fleet/message_handler/message_handler.py +8 -13
  22. flwr/server/superlink/serverappio/serverappio_servicer.py +31 -33
  23. flwr/server/superlink/utils.py +3 -11
  24. flwr/supercore/grpc_health/__init__.py +22 -0
  25. flwr/supercore/grpc_health/simple_health_servicer.py +38 -0
  26. flwr/supercore/object_store/in_memory_object_store.py +31 -31
  27. flwr/supercore/object_store/object_store.py +16 -40
  28. flwr/supernode/runtime/run_clientapp.py +14 -4
  29. flwr/supernode/servicer/clientappio/clientappio_servicer.py +48 -5
  30. flwr/supernode/start_client_internal.py +14 -0
  31. {flwr_nightly-1.20.0.dev20250712.dist-info → flwr_nightly-1.20.0.dev20250715.dist-info}/METADATA +2 -1
  32. {flwr_nightly-1.20.0.dev20250712.dist-info → flwr_nightly-1.20.0.dev20250715.dist-info}/RECORD +34 -28
  33. {flwr_nightly-1.20.0.dev20250712.dist-info → flwr_nightly-1.20.0.dev20250715.dist-info}/WHEEL +0 -0
  34. {flwr_nightly-1.20.0.dev20250712.dist-info → flwr_nightly-1.20.0.dev20250715.dist-info}/entry_points.txt +0 -0
@@ -55,12 +55,12 @@ from flwr.common.serde import (
55
55
  )
56
56
  from flwr.common.telemetry import EventType, event
57
57
  from flwr.common.typing import RunNotRunningException, RunStatus
58
- from flwr.proto.run_pb2 import UpdateRunStatusRequest # pylint: disable=E0611
59
- from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
60
- PullServerAppInputsRequest,
61
- PullServerAppInputsResponse,
62
- PushServerAppOutputsRequest,
58
+ from flwr.proto.appio_pb2 import ( # pylint: disable=E0611
59
+ PullAppInputsRequest,
60
+ PullAppInputsResponse,
61
+ PushAppOutputsRequest,
63
62
  )
63
+ from flwr.proto.run_pb2 import UpdateRunStatusRequest # pylint: disable=E0611
64
64
  from flwr.server.grid.grpc_grid import GrpcGrid
65
65
  from flwr.server.run_serverapp import run as run_
66
66
 
@@ -125,9 +125,9 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
125
125
  )
126
126
 
127
127
  # Pull ServerAppInputs from LinkState
128
- req = PullServerAppInputsRequest()
128
+ req = PullAppInputsRequest()
129
129
  log(DEBUG, "[flwr-serverapp] Pull ServerAppInputs")
130
- res: PullServerAppInputsResponse = grid._stub.PullServerAppInputs(req)
130
+ res: PullAppInputsResponse = grid._stub.PullAppInputs(req)
131
131
  if not res.HasField("run"):
132
132
  sleep(3)
133
133
  run_status = None
@@ -207,10 +207,8 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
207
207
  # Send resulting context
208
208
  context_proto = context_to_proto(updated_context)
209
209
  log(DEBUG, "[flwr-serverapp] Will push ServerAppOutputs")
210
- out_req = PushServerAppOutputsRequest(
211
- run_id=run.run_id, context=context_proto
212
- )
213
- _ = grid._stub.PushServerAppOutputs(out_req)
210
+ out_req = PushAppOutputsRequest(run_id=run.run_id, context=context_proto)
211
+ _ = grid._stub.PushAppOutputs(out_req)
214
212
 
215
213
  run_status = RunStatus(Status.FINISHED, SubStatus.COMPLETED, "")
216
214
  except RunNotRunningException:
@@ -46,7 +46,6 @@ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
46
46
  from flwr.proto.message_pb2 import ( # pylint: disable=E0611
47
47
  ConfirmMessageReceivedRequest,
48
48
  ConfirmMessageReceivedResponse,
49
- ObjectIDs,
50
49
  PullObjectRequest,
51
50
  PullObjectResponse,
52
51
  PushObjectRequest,
@@ -113,25 +112,22 @@ def pull_messages(
113
112
 
114
113
  # Convert to Messages
115
114
  msg_proto = []
116
- objects_to_pull: dict[str, ObjectIDs] = {}
115
+ trees = []
117
116
  for msg in message_list:
118
117
  try:
119
- msg_proto.append(message_to_proto(msg))
120
-
118
+ # Retrieve Message object tree from ObjectStore
121
119
  msg_object_id = msg.metadata.message_id
122
- descendants = store.get_message_descendant_ids(msg_object_id)
123
- # Include the object_id of the message itself
124
- objects_to_pull[msg_object_id] = ObjectIDs(
125
- object_ids=descendants + [msg_object_id]
126
- )
120
+ obj_tree = store.get_object_tree(msg_object_id)
121
+
122
+ # Add Message and its object tree to the response
123
+ msg_proto.append(message_to_proto(msg))
124
+ trees.append(obj_tree)
127
125
  except NoObjectInStoreError as e:
128
126
  log(ERROR, e.message)
129
127
  # Delete message ins from state
130
128
  state.delete_messages(message_ins_ids={msg_object_id})
131
129
 
132
- return PullMessagesResponse(
133
- messages_list=msg_proto, objects_to_pull=objects_to_pull
134
- )
130
+ return PullMessagesResponse(messages_list=msg_proto, message_object_trees=trees)
135
131
 
136
132
 
137
133
  def push_messages(
@@ -287,6 +283,5 @@ def confirm_message_received(
287
283
 
288
284
  # Delete the message object
289
285
  store.delete(request.message_object_id)
290
- store.delete_message_descendant_ids(request.message_object_id)
291
286
 
292
287
  return ConfirmMessageReceivedResponse()
@@ -27,6 +27,7 @@ from flwr.common.inflatable import (
27
27
  UnexpectedObjectContentError,
28
28
  get_all_nested_objects,
29
29
  get_object_tree,
30
+ iterate_object_tree,
30
31
  no_object_id_recompute,
31
32
  )
32
33
  from flwr.common.logger import log
@@ -42,6 +43,16 @@ from flwr.common.serde import (
42
43
  )
43
44
  from flwr.common.typing import Fab, RunStatus
44
45
  from flwr.proto import serverappio_pb2_grpc # pylint: disable=E0611
46
+ from flwr.proto.appio_pb2 import ( # pylint: disable=E0611
47
+ PullAppInputsRequest,
48
+ PullAppInputsResponse,
49
+ PullAppMessagesRequest,
50
+ PullAppMessagesResponse,
51
+ PushAppMessagesRequest,
52
+ PushAppMessagesResponse,
53
+ PushAppOutputsRequest,
54
+ PushAppOutputsResponse,
55
+ )
45
56
  from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
46
57
  from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
47
58
  SendAppHeartbeatRequest,
@@ -72,14 +83,6 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
72
83
  from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
73
84
  GetNodesRequest,
74
85
  GetNodesResponse,
75
- PullResMessagesRequest,
76
- PullResMessagesResponse,
77
- PullServerAppInputsRequest,
78
- PullServerAppInputsResponse,
79
- PushInsMessagesRequest,
80
- PushInsMessagesResponse,
81
- PushServerAppOutputsRequest,
82
- PushServerAppOutputsResponse,
83
86
  )
84
87
  from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
85
88
  from flwr.server.superlink.utils import abort_if
@@ -128,8 +131,8 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
128
131
  return GetNodesResponse(nodes=nodes)
129
132
 
130
133
  def PushMessages(
131
- self, request: PushInsMessagesRequest, context: grpc.ServicerContext
132
- ) -> PushInsMessagesResponse:
134
+ self, request: PushAppMessagesRequest, context: grpc.ServicerContext
135
+ ) -> PushAppMessagesResponse:
133
136
  """Push a set of Messages."""
134
137
  log(DEBUG, "ServerAppIoServicer.PushMessages")
135
138
 
@@ -173,7 +176,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
173
176
  # Store Message object to descendants mapping and preregister objects
174
177
  objects_to_push = store_mapping_and_register_objects(store, request=request)
175
178
 
176
- return PushInsMessagesResponse(
179
+ return PushAppMessagesResponse(
177
180
  message_ids=[
178
181
  str(message_id) if message_id else "" for message_id in message_ids
179
182
  ],
@@ -181,8 +184,8 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
181
184
  )
182
185
 
183
186
  def PullMessages( # pylint: disable=R0914
184
- self, request: PullResMessagesRequest, context: grpc.ServicerContext
185
- ) -> PullResMessagesResponse:
187
+ self, request: PullAppMessagesRequest, context: grpc.ServicerContext
188
+ ) -> PullAppMessagesResponse:
186
189
  """Pull a set of Messages."""
187
190
  log(DEBUG, "ServerAppIoServicer.PullMessages")
188
191
 
@@ -209,12 +212,6 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
209
212
  if msg_res.metadata.src_node_id == SUPERLINK_NODE_ID:
210
213
  with no_object_id_recompute():
211
214
  all_objects = get_all_nested_objects(msg_res)
212
- descendants = list(all_objects.keys())[:-1]
213
- message_obj_id = msg_res.metadata.message_id
214
- # Store mapping
215
- store.set_message_descendant_ids(
216
- msg_object_id=message_obj_id, descendant_ids=descendants
217
- )
218
215
  # Preregister
219
216
  store.preregister(request.run_id, get_object_tree(msg_res))
220
217
  # Store objects
@@ -245,7 +242,9 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
245
242
 
246
243
  try:
247
244
  msg_object_id = msg.metadata.message_id
248
- descendants = store.get_message_descendant_ids(msg_object_id)
245
+ obj_tree = store.get_object_tree(msg_object_id)
246
+ descendants = [node.object_id for node in iterate_object_tree(obj_tree)]
247
+ descendants = descendants[:-1] # Exclude the message itself
249
248
  # Add mapping of message object ID to its descendants
250
249
  objects_to_pull[msg_object_id] = ObjectIDs(object_ids=descendants)
251
250
  except NoObjectInStoreError as e:
@@ -253,7 +252,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
253
252
  # Delete message ins from state
254
253
  state.delete_messages(message_ins_ids={msg_object_id})
255
254
 
256
- return PullResMessagesResponse(
255
+ return PullAppMessagesResponse(
257
256
  messages_list=messages_list, objects_to_pull=objects_to_pull
258
257
  )
259
258
 
@@ -287,11 +286,11 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
287
286
 
288
287
  raise ValueError(f"Found no FAB with hash: {request.hash_str}")
289
288
 
290
- def PullServerAppInputs(
291
- self, request: PullServerAppInputsRequest, context: grpc.ServicerContext
292
- ) -> PullServerAppInputsResponse:
289
+ def PullAppInputs(
290
+ self, request: PullAppInputsRequest, context: grpc.ServicerContext
291
+ ) -> PullAppInputsResponse:
293
292
  """Pull ServerApp process inputs."""
294
- log(DEBUG, "ServerAppIoServicer.PullServerAppInputs")
293
+ log(DEBUG, "ServerAppIoServicer.PullAppInputs")
295
294
  # Init access to LinkState
296
295
  state = self.state_factory.state()
297
296
 
@@ -301,7 +300,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
301
300
  run_id = state.get_pending_run_id()
302
301
  # If there's no pending run, return an empty response
303
302
  if run_id is None:
304
- return PullServerAppInputsResponse()
303
+ return PullAppInputsResponse()
305
304
 
306
305
  # Init access to Ffs
307
306
  ffs = self.ffs_factory.ffs()
@@ -317,7 +316,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
317
316
  # Update run status to STARTING
318
317
  if state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")):
319
318
  log(INFO, "Starting run %d", run_id)
320
- return PullServerAppInputsResponse(
319
+ return PullAppInputsResponse(
321
320
  context=context_to_proto(serverapp_ctxt),
322
321
  run=run_to_proto(run),
323
322
  fab=fab_to_proto(fab),
@@ -327,11 +326,11 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
327
326
  # or if the status cannot be updated to STARTING
328
327
  raise RuntimeError(f"Failed to start run {run_id}")
329
328
 
330
- def PushServerAppOutputs(
331
- self, request: PushServerAppOutputsRequest, context: grpc.ServicerContext
332
- ) -> PushServerAppOutputsResponse:
329
+ def PushAppOutputs(
330
+ self, request: PushAppOutputsRequest, context: grpc.ServicerContext
331
+ ) -> PushAppOutputsResponse:
333
332
  """Push ServerApp process outputs."""
334
- log(DEBUG, "ServerAppIoServicer.PushServerAppOutputs")
333
+ log(DEBUG, "ServerAppIoServicer.PushAppOutputs")
335
334
 
336
335
  # Init state and store
337
336
  state = self.state_factory.state()
@@ -347,7 +346,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
347
346
  )
348
347
 
349
348
  state.set_serverapp_context(request.run_id, context_from_proto(request.context))
350
- return PushServerAppOutputsResponse()
349
+ return PushAppOutputsResponse()
351
350
 
352
351
  def UpdateRunStatus(
353
352
  self, request: UpdateRunStatusRequest, context: grpc.ServicerContext
@@ -511,7 +510,6 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
511
510
 
512
511
  # Delete the message object
513
512
  store.delete(request.message_object_id)
514
- store.delete_message_descendant_ids(request.message_object_id)
515
513
 
516
514
  return ConfirmMessageReceivedResponse()
517
515
 
@@ -20,11 +20,10 @@ from typing import Optional, Union
20
20
  import grpc
21
21
 
22
22
  from flwr.common.constant import Status, SubStatus
23
- from flwr.common.inflatable import iterate_object_tree
24
23
  from flwr.common.typing import RunStatus
24
+ from flwr.proto.appio_pb2 import PushAppMessagesRequest # pylint: disable=E0611
25
25
  from flwr.proto.fleet_pb2 import PushMessagesRequest # pylint: disable=E0611
26
26
  from flwr.proto.message_pb2 import ObjectIDs # pylint: disable=E0611
27
- from flwr.proto.serverappio_pb2 import PushInsMessagesRequest # pylint: disable=E0611
28
27
  from flwr.server.superlink.linkstate import LinkState
29
28
  from flwr.supercore.object_store import ObjectStore
30
29
 
@@ -77,7 +76,7 @@ def abort_if(
77
76
 
78
77
 
79
78
  def store_mapping_and_register_objects(
80
- store: ObjectStore, request: Union[PushInsMessagesRequest, PushMessagesRequest]
79
+ store: ObjectStore, request: Union[PushAppMessagesRequest, PushMessagesRequest]
81
80
  ) -> dict[str, ObjectIDs]:
82
81
  """Store Message object to descendants mapping and preregister objects."""
83
82
  if not request.messages_list:
@@ -90,17 +89,10 @@ def store_mapping_and_register_objects(
90
89
  run_id = request.messages_list[0].metadata.run_id
91
90
 
92
91
  for object_tree in request.message_object_trees:
93
- all_object_ids = [obj.object_id for obj in iterate_object_tree(object_tree)]
94
- msg_object_id, descendant_ids = all_object_ids[-1], all_object_ids[:-1]
95
- # Store mapping
96
- store.set_message_descendant_ids(
97
- msg_object_id=msg_object_id, descendant_ids=descendant_ids
98
- )
99
-
100
92
  # Preregister
101
93
  object_ids_just_registered = store.preregister(run_id, object_tree)
102
94
  # Keep track of objects that need to be pushed
103
- objects_to_push[msg_object_id] = ObjectIDs(
95
+ objects_to_push[object_tree.object_id] = ObjectIDs(
104
96
  object_ids=object_ids_just_registered
105
97
  )
106
98
 
@@ -0,0 +1,22 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """GRPC health servicers."""
16
+
17
+
18
+ from .simple_health_servicer import SimpleHealthServicer
19
+
20
+ __all__ = [
21
+ "SimpleHealthServicer",
22
+ ]
@@ -0,0 +1,38 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Simple gRPC health servicers."""
16
+
17
+
18
+ import grpc
19
+
20
+ # pylint: disable=E0611
21
+ from grpc_health.v1.health_pb2 import HealthCheckRequest, HealthCheckResponse
22
+ from grpc_health.v1.health_pb2_grpc import HealthServicer
23
+
24
+ # pylint: enable=E0611
25
+
26
+
27
+ class SimpleHealthServicer(HealthServicer): # type: ignore
28
+ """A simple gRPC health servicer that always returns SERVING."""
29
+
30
+ def Check(
31
+ self, request: HealthCheckRequest, context: grpc.RpcContext
32
+ ) -> HealthCheckResponse:
33
+ """Return a HealthCheckResponse with SERVING status."""
34
+ return HealthCheckResponse(status=HealthCheckResponse.SERVING)
35
+
36
+ def Watch(self, request: HealthCheckRequest, context: grpc.RpcContext) -> None:
37
+ """Watch the health status (not implemented)."""
38
+ context.abort(grpc.StatusCode.UNIMPLEMENTED, "Watch is not implemented")
@@ -20,7 +20,6 @@ from dataclasses import dataclass
20
20
  from typing import Optional
21
21
 
22
22
  from flwr.common.inflatable import (
23
- get_object_children_ids_from_object_content,
24
23
  get_object_id,
25
24
  is_valid_sha256_hash,
26
25
  iterate_object_tree,
@@ -37,6 +36,7 @@ class ObjectEntry:
37
36
 
38
37
  content: bytes
39
38
  is_available: bool
39
+ child_object_ids: list[str] # List of child object IDs
40
40
  ref_count: int # Number of references (direct parents) to this object
41
41
  runs: set[int] # Set of run IDs that used this object
42
42
 
@@ -70,6 +70,9 @@ class InMemoryObjectStore(ObjectStore):
70
70
  self.store[obj_id] = ObjectEntry(
71
71
  content=b"", # Initially empty content
72
72
  is_available=False, # Initially not available
73
+ child_object_ids=[ # List of child object IDs
74
+ child.object_id for child in tree_node.children
75
+ ],
73
76
  ref_count=0, # Reference count starts at 0
74
77
  runs={run_id}, # Start with the current run ID
75
78
  )
@@ -102,6 +105,32 @@ class InMemoryObjectStore(ObjectStore):
102
105
 
103
106
  return new_objects
104
107
 
108
+ def get_object_tree(self, object_id: str) -> ObjectTree:
109
+ """Get the object tree for a given object ID."""
110
+ with self.lock_store:
111
+ # Raise an exception if there's no object with the given ID
112
+ if not (object_entry := self.store.get(object_id)):
113
+ raise NoObjectInStoreError(
114
+ f"Object with ID '{object_id}' was not pre-registered."
115
+ )
116
+
117
+ # Build the object trees of all children
118
+ try:
119
+ child_trees = [
120
+ self.get_object_tree(child_id)
121
+ for child_id in object_entry.child_object_ids
122
+ ]
123
+ except NoObjectInStoreError as e:
124
+ # Raise an error if any child object is missing
125
+ # This indicates an integrity issue
126
+ raise NoObjectInStoreError(
127
+ f"Object tree for object ID '{object_id}' contains missing "
128
+ "children. This may indicate a corrupted object store."
129
+ ) from e
130
+
131
+ # Create and return the ObjectTree for the current object
132
+ return ObjectTree(object_id=object_id, children=child_trees)
133
+
105
134
  def put(self, object_id: str, object_content: bytes) -> None:
106
135
  """Put an object into the store."""
107
136
  if self.verify:
@@ -128,29 +157,6 @@ class InMemoryObjectStore(ObjectStore):
128
157
  self.store[object_id].content = object_content
129
158
  self.store[object_id].is_available = True
130
159
 
131
- def set_message_descendant_ids(
132
- self, msg_object_id: str, descendant_ids: list[str]
133
- ) -> None:
134
- """Store the mapping from a ``Message`` object ID to the object IDs of its
135
- descendants."""
136
- with self.lock_msg_mapping:
137
- self.msg_descendant_objects_mapping[msg_object_id] = descendant_ids
138
-
139
- def get_message_descendant_ids(self, msg_object_id: str) -> list[str]:
140
- """Retrieve the object IDs of all descendants of a given Message."""
141
- with self.lock_msg_mapping:
142
- if msg_object_id not in self.msg_descendant_objects_mapping:
143
- raise NoObjectInStoreError(
144
- f"No message registered in Object Store with ID '{msg_object_id}'. "
145
- "Mapping to descendants could not be found."
146
- )
147
- return self.msg_descendant_objects_mapping[msg_object_id]
148
-
149
- def delete_message_descendant_ids(self, msg_object_id: str) -> None:
150
- """Delete the mapping from a ``Message`` object ID to its descendants."""
151
- with self.lock_msg_mapping:
152
- self.msg_descendant_objects_mapping.pop(msg_object_id, None)
153
-
154
160
  def get(self, object_id: str) -> Optional[bytes]:
155
161
  """Get an object from the store."""
156
162
  with self.lock_store:
@@ -177,10 +183,7 @@ class InMemoryObjectStore(ObjectStore):
177
183
  self.run_objects_mapping[run_id].discard(object_id)
178
184
 
179
185
  # Decrease the reference count of its children
180
- children_ids = get_object_children_ids_from_object_content(
181
- object_entry.content
182
- )
183
- for child_id in children_ids:
186
+ for child_id in object_entry.child_object_ids:
184
187
  self.store[child_id].ref_count -= 1
185
188
 
186
189
  # Recursively try to delete the child object
@@ -205,9 +208,6 @@ class InMemoryObjectStore(ObjectStore):
205
208
  # Delete the message object and its unreferenced descendants
206
209
  self.delete(object_id)
207
210
 
208
- # Delete the message's descendants mapping
209
- self.delete_message_descendant_ids(object_id)
210
-
211
211
  # Remove the run from the mapping
212
212
  del self.run_objects_mapping[run_id]
213
213
 
@@ -60,6 +60,22 @@ class ObjectStore(abc.ABC):
60
60
  in the `ObjectStore`, or were preregistered but are not yet available.
61
61
  """
62
62
 
63
+ @abc.abstractmethod
64
+ def get_object_tree(self, object_id: str) -> ObjectTree:
65
+ """Get the object tree for a given object ID.
66
+
67
+ Parameters
68
+ ----------
69
+ object_id : str
70
+ The ID of the object for which to retrieve the object tree.
71
+
72
+ Returns
73
+ -------
74
+ ObjectTree
75
+ An ObjectTree representing the hierarchical structure of the object with
76
+ the given ID and its descendants.
77
+ """
78
+
63
79
  @abc.abstractmethod
64
80
  def put(self, object_id: str, object_content: bytes) -> None:
65
81
  """Put an object into the store.
@@ -126,46 +142,6 @@ class ObjectStore(abc.ABC):
126
142
  This method should remove all objects from the store.
127
143
  """
128
144
 
129
- @abc.abstractmethod
130
- def set_message_descendant_ids(
131
- self, msg_object_id: str, descendant_ids: list[str]
132
- ) -> None:
133
- """Store the mapping from a ``Message`` object ID to the object IDs of its
134
- descendants.
135
-
136
- Parameters
137
- ----------
138
- msg_object_id : str
139
- The object ID of the ``Message``.
140
- descendant_ids : list[str]
141
- A list of object IDs representing all descendant objects of the ``Message``.
142
- """
143
-
144
- @abc.abstractmethod
145
- def get_message_descendant_ids(self, msg_object_id: str) -> list[str]:
146
- """Retrieve the object IDs of all descendants of a given ``Message``.
147
-
148
- Parameters
149
- ----------
150
- msg_object_id : str
151
- The object ID of the ``Message``.
152
-
153
- Returns
154
- -------
155
- list[str]
156
- A list of object IDs of all descendant objects of the ``Message``.
157
- """
158
-
159
- @abc.abstractmethod
160
- def delete_message_descendant_ids(self, msg_object_id: str) -> None:
161
- """Delete the mapping from a ``Message`` object ID to its descendants.
162
-
163
- Parameters
164
- ----------
165
- msg_object_id : str
166
- The object ID of the ``Message``.
167
- """
168
-
169
145
  @abc.abstractmethod
170
146
  def __contains__(self, object_id: str) -> bool:
171
147
  """Check if an object_id is in the store.
@@ -50,8 +50,11 @@ from flwr.proto.clientappio_pb2 import (
50
50
  GetRunIdsWithPendingMessagesResponse,
51
51
  PullClientAppInputsRequest,
52
52
  PullClientAppInputsResponse,
53
+ PullMessageRequest,
54
+ PullMessageResponse,
53
55
  PushClientAppOutputsRequest,
54
56
  PushClientAppOutputsResponse,
57
+ PushMessageRequest,
55
58
  RequestTokenRequest,
56
59
  RequestTokenResponse,
57
60
  )
@@ -199,10 +202,14 @@ def pull_clientappinputs(
199
202
  masked_token = mask_string(token)
200
203
  log(INFO, "[flwr-clientapp] Pull `ClientAppInputs` for token %s", masked_token)
201
204
  try:
205
+ # Pull Message
206
+ res_msg: PullMessageResponse = stub.PullMessage(PullMessageRequest(token=token))
207
+ message = message_from_proto(res_msg.message)
208
+
209
+ # Pull Context, Run and (optional) FAB
202
210
  res: PullClientAppInputsResponse = stub.PullClientAppInputs(
203
211
  PullClientAppInputsRequest(token=token)
204
212
  )
205
- message = message_from_proto(res.message)
206
213
  context = context_from_proto(res.context)
207
214
  run = run_from_proto(res.run)
208
215
  fab = fab_from_proto(res.fab) if res.fab else None
@@ -224,10 +231,13 @@ def push_clientappoutputs(
224
231
  proto_context = context_to_proto(context)
225
232
 
226
233
  try:
234
+
235
+ # Push Message
236
+ _ = stub.PushMessage(PushMessageRequest(token=token, message=proto_message))
237
+
238
+ # Push Context
227
239
  res: PushClientAppOutputsResponse = stub.PushClientAppOutputs(
228
- PushClientAppOutputsRequest(
229
- token=token, message=proto_message, context=proto_context
230
- )
240
+ PushClientAppOutputsRequest(token=token, context=proto_context)
231
241
  )
232
242
  return res
233
243
  except grpc.RpcError as e:
@@ -39,8 +39,12 @@ from flwr.proto.clientappio_pb2 import ( # pylint: disable=E0401
39
39
  GetRunIdsWithPendingMessagesResponse,
40
40
  PullClientAppInputsRequest,
41
41
  PullClientAppInputsResponse,
42
+ PullMessageRequest,
43
+ PullMessageResponse,
42
44
  PushClientAppOutputsRequest,
43
45
  PushClientAppOutputsResponse,
46
+ PushMessageRequest,
47
+ PushMessageResponse,
44
48
  RequestTokenRequest,
45
49
  RequestTokenResponse,
46
50
  )
@@ -119,14 +123,12 @@ class ClientAppIoServicer(clientappio_pb2_grpc.ClientAppIoServicer):
119
123
  )
120
124
  raise RuntimeError("This line should never be reached.")
121
125
 
122
- # Retrieve message, context, run and fab for this run
123
- message = state.get_messages(run_ids=[run_id], is_reply=False)[0]
126
+ # Retrieve context, run and fab for this run
124
127
  context = cast(Context, state.get_context(run_id))
125
128
  run = cast(Run, state.get_run(run_id))
126
129
  fab = Fab(run.fab_hash, ffs.get(run.fab_hash)[0]) # type: ignore
127
130
 
128
131
  return PullClientAppInputsResponse(
129
- message=message_to_proto(message),
130
132
  context=context_to_proto(context),
131
133
  run=run_to_proto(run),
132
134
  fab=fab_to_proto(fab),
@@ -150,8 +152,7 @@ class ClientAppIoServicer(clientappio_pb2_grpc.ClientAppIoServicer):
150
152
  )
151
153
  raise RuntimeError("This line should never be reached.")
152
154
 
153
- # Save the message and context to the state
154
- state.store_message(message_from_proto(request.message))
155
+ # Save the context to the state
155
156
  state.store_context(context_from_proto(request.context))
156
157
 
157
158
  # Remove the token to make the run eligible for processing
@@ -159,3 +160,45 @@ class ClientAppIoServicer(clientappio_pb2_grpc.ClientAppIoServicer):
159
160
  state.delete_token(run_id)
160
161
 
161
162
  return PushClientAppOutputsResponse()
163
+
164
+ def PullMessage(
165
+ self, request: PullMessageRequest, context: grpc.ServicerContext
166
+ ) -> PullMessageResponse:
167
+ """Pull one Message."""
168
+ # Initialize state and ffs connection
169
+ state = self.state_factory.state()
170
+
171
+ # Validate the token
172
+ run_id = state.get_run_id_by_token(request.token)
173
+ if run_id is None or not state.verify_token(run_id, request.token):
174
+ context.abort(
175
+ grpc.StatusCode.PERMISSION_DENIED,
176
+ "Invalid token.",
177
+ )
178
+ raise RuntimeError("This line should never be reached.")
179
+
180
+ # Retrieve message, context, run and fab for this run
181
+ message = state.get_messages(run_ids=[run_id], is_reply=False)[0]
182
+
183
+ return PullMessageResponse(message=message_to_proto(message))
184
+
185
+ def PushMessage(
186
+ self, request: PushMessageRequest, context: grpc.ServicerContext
187
+ ) -> PushMessageResponse:
188
+ """Push one Message."""
189
+ # Initialize state connection
190
+ state = self.state_factory.state()
191
+
192
+ # Validate the token
193
+ run_id = state.get_run_id_by_token(request.token)
194
+ if run_id is None or not state.verify_token(run_id, request.token):
195
+ context.abort(
196
+ grpc.StatusCode.PERMISSION_DENIED,
197
+ "Invalid token.",
198
+ )
199
+ raise RuntimeError("This line should never be reached.")
200
+
201
+ # Save the message and context to the state
202
+ state.store_message(message_from_proto(request.message))
203
+
204
+ return PushMessageResponse()