flwr-nightly 1.20.0.dev20250712__py3-none-any.whl → 1.20.0.dev20250715__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/client/grpc_rere_client/connection.py +3 -1
- flwr/client/rest_client/connection.py +3 -1
- flwr/common/grpc.py +12 -1
- flwr/common/inflatable_utils.py +14 -7
- flwr/proto/appio_pb2.py +51 -0
- flwr/proto/appio_pb2.pyi +167 -0
- flwr/proto/appio_pb2_grpc.py +4 -0
- flwr/proto/appio_pb2_grpc.pyi +4 -0
- flwr/proto/clientappio_pb2.py +19 -11
- flwr/proto/clientappio_pb2.pyi +50 -12
- flwr/proto/clientappio_pb2_grpc.py +68 -0
- flwr/proto/clientappio_pb2_grpc.pyi +26 -0
- flwr/proto/fleet_pb2.py +14 -18
- flwr/proto/fleet_pb2.pyi +4 -19
- flwr/proto/serverappio_pb2.py +8 -31
- flwr/proto/serverappio_pb2.pyi +0 -152
- flwr/proto/serverappio_pb2_grpc.py +39 -38
- flwr/proto/serverappio_pb2_grpc.pyi +21 -20
- flwr/server/grid/grpc_grid.py +10 -8
- flwr/server/serverapp/app.py +9 -11
- flwr/server/superlink/fleet/message_handler/message_handler.py +8 -13
- flwr/server/superlink/serverappio/serverappio_servicer.py +31 -33
- flwr/server/superlink/utils.py +3 -11
- flwr/supercore/grpc_health/__init__.py +22 -0
- flwr/supercore/grpc_health/simple_health_servicer.py +38 -0
- flwr/supercore/object_store/in_memory_object_store.py +31 -31
- flwr/supercore/object_store/object_store.py +16 -40
- flwr/supernode/runtime/run_clientapp.py +14 -4
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +48 -5
- flwr/supernode/start_client_internal.py +14 -0
- {flwr_nightly-1.20.0.dev20250712.dist-info → flwr_nightly-1.20.0.dev20250715.dist-info}/METADATA +2 -1
- {flwr_nightly-1.20.0.dev20250712.dist-info → flwr_nightly-1.20.0.dev20250715.dist-info}/RECORD +34 -28
- {flwr_nightly-1.20.0.dev20250712.dist-info → flwr_nightly-1.20.0.dev20250715.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.20.0.dev20250712.dist-info → flwr_nightly-1.20.0.dev20250715.dist-info}/entry_points.txt +0 -0
flwr/server/serverapp/app.py
CHANGED
@@ -55,12 +55,12 @@ from flwr.common.serde import (
|
|
55
55
|
)
|
56
56
|
from flwr.common.telemetry import EventType, event
|
57
57
|
from flwr.common.typing import RunNotRunningException, RunStatus
|
58
|
-
from flwr.proto.
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
PushServerAppOutputsRequest,
|
58
|
+
from flwr.proto.appio_pb2 import ( # pylint: disable=E0611
|
59
|
+
PullAppInputsRequest,
|
60
|
+
PullAppInputsResponse,
|
61
|
+
PushAppOutputsRequest,
|
63
62
|
)
|
63
|
+
from flwr.proto.run_pb2 import UpdateRunStatusRequest # pylint: disable=E0611
|
64
64
|
from flwr.server.grid.grpc_grid import GrpcGrid
|
65
65
|
from flwr.server.run_serverapp import run as run_
|
66
66
|
|
@@ -125,9 +125,9 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
125
125
|
)
|
126
126
|
|
127
127
|
# Pull ServerAppInputs from LinkState
|
128
|
-
req =
|
128
|
+
req = PullAppInputsRequest()
|
129
129
|
log(DEBUG, "[flwr-serverapp] Pull ServerAppInputs")
|
130
|
-
res:
|
130
|
+
res: PullAppInputsResponse = grid._stub.PullAppInputs(req)
|
131
131
|
if not res.HasField("run"):
|
132
132
|
sleep(3)
|
133
133
|
run_status = None
|
@@ -207,10 +207,8 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
207
207
|
# Send resulting context
|
208
208
|
context_proto = context_to_proto(updated_context)
|
209
209
|
log(DEBUG, "[flwr-serverapp] Will push ServerAppOutputs")
|
210
|
-
out_req =
|
211
|
-
|
212
|
-
)
|
213
|
-
_ = grid._stub.PushServerAppOutputs(out_req)
|
210
|
+
out_req = PushAppOutputsRequest(run_id=run.run_id, context=context_proto)
|
211
|
+
_ = grid._stub.PushAppOutputs(out_req)
|
214
212
|
|
215
213
|
run_status = RunStatus(Status.FINISHED, SubStatus.COMPLETED, "")
|
216
214
|
except RunNotRunningException:
|
@@ -46,7 +46,6 @@ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
|
|
46
46
|
from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
47
47
|
ConfirmMessageReceivedRequest,
|
48
48
|
ConfirmMessageReceivedResponse,
|
49
|
-
ObjectIDs,
|
50
49
|
PullObjectRequest,
|
51
50
|
PullObjectResponse,
|
52
51
|
PushObjectRequest,
|
@@ -113,25 +112,22 @@ def pull_messages(
|
|
113
112
|
|
114
113
|
# Convert to Messages
|
115
114
|
msg_proto = []
|
116
|
-
|
115
|
+
trees = []
|
117
116
|
for msg in message_list:
|
118
117
|
try:
|
119
|
-
|
120
|
-
|
118
|
+
# Retrieve Message object tree from ObjectStore
|
121
119
|
msg_object_id = msg.metadata.message_id
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
)
|
120
|
+
obj_tree = store.get_object_tree(msg_object_id)
|
121
|
+
|
122
|
+
# Add Message and its object tree to the response
|
123
|
+
msg_proto.append(message_to_proto(msg))
|
124
|
+
trees.append(obj_tree)
|
127
125
|
except NoObjectInStoreError as e:
|
128
126
|
log(ERROR, e.message)
|
129
127
|
# Delete message ins from state
|
130
128
|
state.delete_messages(message_ins_ids={msg_object_id})
|
131
129
|
|
132
|
-
return PullMessagesResponse(
|
133
|
-
messages_list=msg_proto, objects_to_pull=objects_to_pull
|
134
|
-
)
|
130
|
+
return PullMessagesResponse(messages_list=msg_proto, message_object_trees=trees)
|
135
131
|
|
136
132
|
|
137
133
|
def push_messages(
|
@@ -287,6 +283,5 @@ def confirm_message_received(
|
|
287
283
|
|
288
284
|
# Delete the message object
|
289
285
|
store.delete(request.message_object_id)
|
290
|
-
store.delete_message_descendant_ids(request.message_object_id)
|
291
286
|
|
292
287
|
return ConfirmMessageReceivedResponse()
|
@@ -27,6 +27,7 @@ from flwr.common.inflatable import (
|
|
27
27
|
UnexpectedObjectContentError,
|
28
28
|
get_all_nested_objects,
|
29
29
|
get_object_tree,
|
30
|
+
iterate_object_tree,
|
30
31
|
no_object_id_recompute,
|
31
32
|
)
|
32
33
|
from flwr.common.logger import log
|
@@ -42,6 +43,16 @@ from flwr.common.serde import (
|
|
42
43
|
)
|
43
44
|
from flwr.common.typing import Fab, RunStatus
|
44
45
|
from flwr.proto import serverappio_pb2_grpc # pylint: disable=E0611
|
46
|
+
from flwr.proto.appio_pb2 import ( # pylint: disable=E0611
|
47
|
+
PullAppInputsRequest,
|
48
|
+
PullAppInputsResponse,
|
49
|
+
PullAppMessagesRequest,
|
50
|
+
PullAppMessagesResponse,
|
51
|
+
PushAppMessagesRequest,
|
52
|
+
PushAppMessagesResponse,
|
53
|
+
PushAppOutputsRequest,
|
54
|
+
PushAppOutputsResponse,
|
55
|
+
)
|
45
56
|
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
46
57
|
from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
|
47
58
|
SendAppHeartbeatRequest,
|
@@ -72,14 +83,6 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
72
83
|
from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
|
73
84
|
GetNodesRequest,
|
74
85
|
GetNodesResponse,
|
75
|
-
PullResMessagesRequest,
|
76
|
-
PullResMessagesResponse,
|
77
|
-
PullServerAppInputsRequest,
|
78
|
-
PullServerAppInputsResponse,
|
79
|
-
PushInsMessagesRequest,
|
80
|
-
PushInsMessagesResponse,
|
81
|
-
PushServerAppOutputsRequest,
|
82
|
-
PushServerAppOutputsResponse,
|
83
86
|
)
|
84
87
|
from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
|
85
88
|
from flwr.server.superlink.utils import abort_if
|
@@ -128,8 +131,8 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
128
131
|
return GetNodesResponse(nodes=nodes)
|
129
132
|
|
130
133
|
def PushMessages(
|
131
|
-
self, request:
|
132
|
-
) ->
|
134
|
+
self, request: PushAppMessagesRequest, context: grpc.ServicerContext
|
135
|
+
) -> PushAppMessagesResponse:
|
133
136
|
"""Push a set of Messages."""
|
134
137
|
log(DEBUG, "ServerAppIoServicer.PushMessages")
|
135
138
|
|
@@ -173,7 +176,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
173
176
|
# Store Message object to descendants mapping and preregister objects
|
174
177
|
objects_to_push = store_mapping_and_register_objects(store, request=request)
|
175
178
|
|
176
|
-
return
|
179
|
+
return PushAppMessagesResponse(
|
177
180
|
message_ids=[
|
178
181
|
str(message_id) if message_id else "" for message_id in message_ids
|
179
182
|
],
|
@@ -181,8 +184,8 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
181
184
|
)
|
182
185
|
|
183
186
|
def PullMessages( # pylint: disable=R0914
|
184
|
-
self, request:
|
185
|
-
) ->
|
187
|
+
self, request: PullAppMessagesRequest, context: grpc.ServicerContext
|
188
|
+
) -> PullAppMessagesResponse:
|
186
189
|
"""Pull a set of Messages."""
|
187
190
|
log(DEBUG, "ServerAppIoServicer.PullMessages")
|
188
191
|
|
@@ -209,12 +212,6 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
209
212
|
if msg_res.metadata.src_node_id == SUPERLINK_NODE_ID:
|
210
213
|
with no_object_id_recompute():
|
211
214
|
all_objects = get_all_nested_objects(msg_res)
|
212
|
-
descendants = list(all_objects.keys())[:-1]
|
213
|
-
message_obj_id = msg_res.metadata.message_id
|
214
|
-
# Store mapping
|
215
|
-
store.set_message_descendant_ids(
|
216
|
-
msg_object_id=message_obj_id, descendant_ids=descendants
|
217
|
-
)
|
218
215
|
# Preregister
|
219
216
|
store.preregister(request.run_id, get_object_tree(msg_res))
|
220
217
|
# Store objects
|
@@ -245,7 +242,9 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
245
242
|
|
246
243
|
try:
|
247
244
|
msg_object_id = msg.metadata.message_id
|
248
|
-
|
245
|
+
obj_tree = store.get_object_tree(msg_object_id)
|
246
|
+
descendants = [node.object_id for node in iterate_object_tree(obj_tree)]
|
247
|
+
descendants = descendants[:-1] # Exclude the message itself
|
249
248
|
# Add mapping of message object ID to its descendants
|
250
249
|
objects_to_pull[msg_object_id] = ObjectIDs(object_ids=descendants)
|
251
250
|
except NoObjectInStoreError as e:
|
@@ -253,7 +252,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
253
252
|
# Delete message ins from state
|
254
253
|
state.delete_messages(message_ins_ids={msg_object_id})
|
255
254
|
|
256
|
-
return
|
255
|
+
return PullAppMessagesResponse(
|
257
256
|
messages_list=messages_list, objects_to_pull=objects_to_pull
|
258
257
|
)
|
259
258
|
|
@@ -287,11 +286,11 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
287
286
|
|
288
287
|
raise ValueError(f"Found no FAB with hash: {request.hash_str}")
|
289
288
|
|
290
|
-
def
|
291
|
-
self, request:
|
292
|
-
) ->
|
289
|
+
def PullAppInputs(
|
290
|
+
self, request: PullAppInputsRequest, context: grpc.ServicerContext
|
291
|
+
) -> PullAppInputsResponse:
|
293
292
|
"""Pull ServerApp process inputs."""
|
294
|
-
log(DEBUG, "ServerAppIoServicer.
|
293
|
+
log(DEBUG, "ServerAppIoServicer.PullAppInputs")
|
295
294
|
# Init access to LinkState
|
296
295
|
state = self.state_factory.state()
|
297
296
|
|
@@ -301,7 +300,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
301
300
|
run_id = state.get_pending_run_id()
|
302
301
|
# If there's no pending run, return an empty response
|
303
302
|
if run_id is None:
|
304
|
-
return
|
303
|
+
return PullAppInputsResponse()
|
305
304
|
|
306
305
|
# Init access to Ffs
|
307
306
|
ffs = self.ffs_factory.ffs()
|
@@ -317,7 +316,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
317
316
|
# Update run status to STARTING
|
318
317
|
if state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")):
|
319
318
|
log(INFO, "Starting run %d", run_id)
|
320
|
-
return
|
319
|
+
return PullAppInputsResponse(
|
321
320
|
context=context_to_proto(serverapp_ctxt),
|
322
321
|
run=run_to_proto(run),
|
323
322
|
fab=fab_to_proto(fab),
|
@@ -327,11 +326,11 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
327
326
|
# or if the status cannot be updated to STARTING
|
328
327
|
raise RuntimeError(f"Failed to start run {run_id}")
|
329
328
|
|
330
|
-
def
|
331
|
-
self, request:
|
332
|
-
) ->
|
329
|
+
def PushAppOutputs(
|
330
|
+
self, request: PushAppOutputsRequest, context: grpc.ServicerContext
|
331
|
+
) -> PushAppOutputsResponse:
|
333
332
|
"""Push ServerApp process outputs."""
|
334
|
-
log(DEBUG, "ServerAppIoServicer.
|
333
|
+
log(DEBUG, "ServerAppIoServicer.PushAppOutputs")
|
335
334
|
|
336
335
|
# Init state and store
|
337
336
|
state = self.state_factory.state()
|
@@ -347,7 +346,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
347
346
|
)
|
348
347
|
|
349
348
|
state.set_serverapp_context(request.run_id, context_from_proto(request.context))
|
350
|
-
return
|
349
|
+
return PushAppOutputsResponse()
|
351
350
|
|
352
351
|
def UpdateRunStatus(
|
353
352
|
self, request: UpdateRunStatusRequest, context: grpc.ServicerContext
|
@@ -511,7 +510,6 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
511
510
|
|
512
511
|
# Delete the message object
|
513
512
|
store.delete(request.message_object_id)
|
514
|
-
store.delete_message_descendant_ids(request.message_object_id)
|
515
513
|
|
516
514
|
return ConfirmMessageReceivedResponse()
|
517
515
|
|
flwr/server/superlink/utils.py
CHANGED
@@ -20,11 +20,10 @@ from typing import Optional, Union
|
|
20
20
|
import grpc
|
21
21
|
|
22
22
|
from flwr.common.constant import Status, SubStatus
|
23
|
-
from flwr.common.inflatable import iterate_object_tree
|
24
23
|
from flwr.common.typing import RunStatus
|
24
|
+
from flwr.proto.appio_pb2 import PushAppMessagesRequest # pylint: disable=E0611
|
25
25
|
from flwr.proto.fleet_pb2 import PushMessagesRequest # pylint: disable=E0611
|
26
26
|
from flwr.proto.message_pb2 import ObjectIDs # pylint: disable=E0611
|
27
|
-
from flwr.proto.serverappio_pb2 import PushInsMessagesRequest # pylint: disable=E0611
|
28
27
|
from flwr.server.superlink.linkstate import LinkState
|
29
28
|
from flwr.supercore.object_store import ObjectStore
|
30
29
|
|
@@ -77,7 +76,7 @@ def abort_if(
|
|
77
76
|
|
78
77
|
|
79
78
|
def store_mapping_and_register_objects(
|
80
|
-
store: ObjectStore, request: Union[
|
79
|
+
store: ObjectStore, request: Union[PushAppMessagesRequest, PushMessagesRequest]
|
81
80
|
) -> dict[str, ObjectIDs]:
|
82
81
|
"""Store Message object to descendants mapping and preregister objects."""
|
83
82
|
if not request.messages_list:
|
@@ -90,17 +89,10 @@ def store_mapping_and_register_objects(
|
|
90
89
|
run_id = request.messages_list[0].metadata.run_id
|
91
90
|
|
92
91
|
for object_tree in request.message_object_trees:
|
93
|
-
all_object_ids = [obj.object_id for obj in iterate_object_tree(object_tree)]
|
94
|
-
msg_object_id, descendant_ids = all_object_ids[-1], all_object_ids[:-1]
|
95
|
-
# Store mapping
|
96
|
-
store.set_message_descendant_ids(
|
97
|
-
msg_object_id=msg_object_id, descendant_ids=descendant_ids
|
98
|
-
)
|
99
|
-
|
100
92
|
# Preregister
|
101
93
|
object_ids_just_registered = store.preregister(run_id, object_tree)
|
102
94
|
# Keep track of objects that need to be pushed
|
103
|
-
objects_to_push[
|
95
|
+
objects_to_push[object_tree.object_id] = ObjectIDs(
|
104
96
|
object_ids=object_ids_just_registered
|
105
97
|
)
|
106
98
|
|
@@ -0,0 +1,22 @@
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""GRPC health servicers."""
|
16
|
+
|
17
|
+
|
18
|
+
from .simple_health_servicer import SimpleHealthServicer
|
19
|
+
|
20
|
+
__all__ = [
|
21
|
+
"SimpleHealthServicer",
|
22
|
+
]
|
@@ -0,0 +1,38 @@
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Simple gRPC health servicers."""
|
16
|
+
|
17
|
+
|
18
|
+
import grpc
|
19
|
+
|
20
|
+
# pylint: disable=E0611
|
21
|
+
from grpc_health.v1.health_pb2 import HealthCheckRequest, HealthCheckResponse
|
22
|
+
from grpc_health.v1.health_pb2_grpc import HealthServicer
|
23
|
+
|
24
|
+
# pylint: enable=E0611
|
25
|
+
|
26
|
+
|
27
|
+
class SimpleHealthServicer(HealthServicer): # type: ignore
|
28
|
+
"""A simple gRPC health servicer that always returns SERVING."""
|
29
|
+
|
30
|
+
def Check(
|
31
|
+
self, request: HealthCheckRequest, context: grpc.RpcContext
|
32
|
+
) -> HealthCheckResponse:
|
33
|
+
"""Return a HealthCheckResponse with SERVING status."""
|
34
|
+
return HealthCheckResponse(status=HealthCheckResponse.SERVING)
|
35
|
+
|
36
|
+
def Watch(self, request: HealthCheckRequest, context: grpc.RpcContext) -> None:
|
37
|
+
"""Watch the health status (not implemented)."""
|
38
|
+
context.abort(grpc.StatusCode.UNIMPLEMENTED, "Watch is not implemented")
|
@@ -20,7 +20,6 @@ from dataclasses import dataclass
|
|
20
20
|
from typing import Optional
|
21
21
|
|
22
22
|
from flwr.common.inflatable import (
|
23
|
-
get_object_children_ids_from_object_content,
|
24
23
|
get_object_id,
|
25
24
|
is_valid_sha256_hash,
|
26
25
|
iterate_object_tree,
|
@@ -37,6 +36,7 @@ class ObjectEntry:
|
|
37
36
|
|
38
37
|
content: bytes
|
39
38
|
is_available: bool
|
39
|
+
child_object_ids: list[str] # List of child object IDs
|
40
40
|
ref_count: int # Number of references (direct parents) to this object
|
41
41
|
runs: set[int] # Set of run IDs that used this object
|
42
42
|
|
@@ -70,6 +70,9 @@ class InMemoryObjectStore(ObjectStore):
|
|
70
70
|
self.store[obj_id] = ObjectEntry(
|
71
71
|
content=b"", # Initially empty content
|
72
72
|
is_available=False, # Initially not available
|
73
|
+
child_object_ids=[ # List of child object IDs
|
74
|
+
child.object_id for child in tree_node.children
|
75
|
+
],
|
73
76
|
ref_count=0, # Reference count starts at 0
|
74
77
|
runs={run_id}, # Start with the current run ID
|
75
78
|
)
|
@@ -102,6 +105,32 @@ class InMemoryObjectStore(ObjectStore):
|
|
102
105
|
|
103
106
|
return new_objects
|
104
107
|
|
108
|
+
def get_object_tree(self, object_id: str) -> ObjectTree:
|
109
|
+
"""Get the object tree for a given object ID."""
|
110
|
+
with self.lock_store:
|
111
|
+
# Raise an exception if there's no object with the given ID
|
112
|
+
if not (object_entry := self.store.get(object_id)):
|
113
|
+
raise NoObjectInStoreError(
|
114
|
+
f"Object with ID '{object_id}' was not pre-registered."
|
115
|
+
)
|
116
|
+
|
117
|
+
# Build the object trees of all children
|
118
|
+
try:
|
119
|
+
child_trees = [
|
120
|
+
self.get_object_tree(child_id)
|
121
|
+
for child_id in object_entry.child_object_ids
|
122
|
+
]
|
123
|
+
except NoObjectInStoreError as e:
|
124
|
+
# Raise an error if any child object is missing
|
125
|
+
# This indicates an integrity issue
|
126
|
+
raise NoObjectInStoreError(
|
127
|
+
f"Object tree for object ID '{object_id}' contains missing "
|
128
|
+
"children. This may indicate a corrupted object store."
|
129
|
+
) from e
|
130
|
+
|
131
|
+
# Create and return the ObjectTree for the current object
|
132
|
+
return ObjectTree(object_id=object_id, children=child_trees)
|
133
|
+
|
105
134
|
def put(self, object_id: str, object_content: bytes) -> None:
|
106
135
|
"""Put an object into the store."""
|
107
136
|
if self.verify:
|
@@ -128,29 +157,6 @@ class InMemoryObjectStore(ObjectStore):
|
|
128
157
|
self.store[object_id].content = object_content
|
129
158
|
self.store[object_id].is_available = True
|
130
159
|
|
131
|
-
def set_message_descendant_ids(
|
132
|
-
self, msg_object_id: str, descendant_ids: list[str]
|
133
|
-
) -> None:
|
134
|
-
"""Store the mapping from a ``Message`` object ID to the object IDs of its
|
135
|
-
descendants."""
|
136
|
-
with self.lock_msg_mapping:
|
137
|
-
self.msg_descendant_objects_mapping[msg_object_id] = descendant_ids
|
138
|
-
|
139
|
-
def get_message_descendant_ids(self, msg_object_id: str) -> list[str]:
|
140
|
-
"""Retrieve the object IDs of all descendants of a given Message."""
|
141
|
-
with self.lock_msg_mapping:
|
142
|
-
if msg_object_id not in self.msg_descendant_objects_mapping:
|
143
|
-
raise NoObjectInStoreError(
|
144
|
-
f"No message registered in Object Store with ID '{msg_object_id}'. "
|
145
|
-
"Mapping to descendants could not be found."
|
146
|
-
)
|
147
|
-
return self.msg_descendant_objects_mapping[msg_object_id]
|
148
|
-
|
149
|
-
def delete_message_descendant_ids(self, msg_object_id: str) -> None:
|
150
|
-
"""Delete the mapping from a ``Message`` object ID to its descendants."""
|
151
|
-
with self.lock_msg_mapping:
|
152
|
-
self.msg_descendant_objects_mapping.pop(msg_object_id, None)
|
153
|
-
|
154
160
|
def get(self, object_id: str) -> Optional[bytes]:
|
155
161
|
"""Get an object from the store."""
|
156
162
|
with self.lock_store:
|
@@ -177,10 +183,7 @@ class InMemoryObjectStore(ObjectStore):
|
|
177
183
|
self.run_objects_mapping[run_id].discard(object_id)
|
178
184
|
|
179
185
|
# Decrease the reference count of its children
|
180
|
-
|
181
|
-
object_entry.content
|
182
|
-
)
|
183
|
-
for child_id in children_ids:
|
186
|
+
for child_id in object_entry.child_object_ids:
|
184
187
|
self.store[child_id].ref_count -= 1
|
185
188
|
|
186
189
|
# Recursively try to delete the child object
|
@@ -205,9 +208,6 @@ class InMemoryObjectStore(ObjectStore):
|
|
205
208
|
# Delete the message object and its unreferenced descendants
|
206
209
|
self.delete(object_id)
|
207
210
|
|
208
|
-
# Delete the message's descendants mapping
|
209
|
-
self.delete_message_descendant_ids(object_id)
|
210
|
-
|
211
211
|
# Remove the run from the mapping
|
212
212
|
del self.run_objects_mapping[run_id]
|
213
213
|
|
@@ -60,6 +60,22 @@ class ObjectStore(abc.ABC):
|
|
60
60
|
in the `ObjectStore`, or were preregistered but are not yet available.
|
61
61
|
"""
|
62
62
|
|
63
|
+
@abc.abstractmethod
|
64
|
+
def get_object_tree(self, object_id: str) -> ObjectTree:
|
65
|
+
"""Get the object tree for a given object ID.
|
66
|
+
|
67
|
+
Parameters
|
68
|
+
----------
|
69
|
+
object_id : str
|
70
|
+
The ID of the object for which to retrieve the object tree.
|
71
|
+
|
72
|
+
Returns
|
73
|
+
-------
|
74
|
+
ObjectTree
|
75
|
+
An ObjectTree representing the hierarchical structure of the object with
|
76
|
+
the given ID and its descendants.
|
77
|
+
"""
|
78
|
+
|
63
79
|
@abc.abstractmethod
|
64
80
|
def put(self, object_id: str, object_content: bytes) -> None:
|
65
81
|
"""Put an object into the store.
|
@@ -126,46 +142,6 @@ class ObjectStore(abc.ABC):
|
|
126
142
|
This method should remove all objects from the store.
|
127
143
|
"""
|
128
144
|
|
129
|
-
@abc.abstractmethod
|
130
|
-
def set_message_descendant_ids(
|
131
|
-
self, msg_object_id: str, descendant_ids: list[str]
|
132
|
-
) -> None:
|
133
|
-
"""Store the mapping from a ``Message`` object ID to the object IDs of its
|
134
|
-
descendants.
|
135
|
-
|
136
|
-
Parameters
|
137
|
-
----------
|
138
|
-
msg_object_id : str
|
139
|
-
The object ID of the ``Message``.
|
140
|
-
descendant_ids : list[str]
|
141
|
-
A list of object IDs representing all descendant objects of the ``Message``.
|
142
|
-
"""
|
143
|
-
|
144
|
-
@abc.abstractmethod
|
145
|
-
def get_message_descendant_ids(self, msg_object_id: str) -> list[str]:
|
146
|
-
"""Retrieve the object IDs of all descendants of a given ``Message``.
|
147
|
-
|
148
|
-
Parameters
|
149
|
-
----------
|
150
|
-
msg_object_id : str
|
151
|
-
The object ID of the ``Message``.
|
152
|
-
|
153
|
-
Returns
|
154
|
-
-------
|
155
|
-
list[str]
|
156
|
-
A list of object IDs of all descendant objects of the ``Message``.
|
157
|
-
"""
|
158
|
-
|
159
|
-
@abc.abstractmethod
|
160
|
-
def delete_message_descendant_ids(self, msg_object_id: str) -> None:
|
161
|
-
"""Delete the mapping from a ``Message`` object ID to its descendants.
|
162
|
-
|
163
|
-
Parameters
|
164
|
-
----------
|
165
|
-
msg_object_id : str
|
166
|
-
The object ID of the ``Message``.
|
167
|
-
"""
|
168
|
-
|
169
145
|
@abc.abstractmethod
|
170
146
|
def __contains__(self, object_id: str) -> bool:
|
171
147
|
"""Check if an object_id is in the store.
|
@@ -50,8 +50,11 @@ from flwr.proto.clientappio_pb2 import (
|
|
50
50
|
GetRunIdsWithPendingMessagesResponse,
|
51
51
|
PullClientAppInputsRequest,
|
52
52
|
PullClientAppInputsResponse,
|
53
|
+
PullMessageRequest,
|
54
|
+
PullMessageResponse,
|
53
55
|
PushClientAppOutputsRequest,
|
54
56
|
PushClientAppOutputsResponse,
|
57
|
+
PushMessageRequest,
|
55
58
|
RequestTokenRequest,
|
56
59
|
RequestTokenResponse,
|
57
60
|
)
|
@@ -199,10 +202,14 @@ def pull_clientappinputs(
|
|
199
202
|
masked_token = mask_string(token)
|
200
203
|
log(INFO, "[flwr-clientapp] Pull `ClientAppInputs` for token %s", masked_token)
|
201
204
|
try:
|
205
|
+
# Pull Message
|
206
|
+
res_msg: PullMessageResponse = stub.PullMessage(PullMessageRequest(token=token))
|
207
|
+
message = message_from_proto(res_msg.message)
|
208
|
+
|
209
|
+
# Pull Context, Run and (optional) FAB
|
202
210
|
res: PullClientAppInputsResponse = stub.PullClientAppInputs(
|
203
211
|
PullClientAppInputsRequest(token=token)
|
204
212
|
)
|
205
|
-
message = message_from_proto(res.message)
|
206
213
|
context = context_from_proto(res.context)
|
207
214
|
run = run_from_proto(res.run)
|
208
215
|
fab = fab_from_proto(res.fab) if res.fab else None
|
@@ -224,10 +231,13 @@ def push_clientappoutputs(
|
|
224
231
|
proto_context = context_to_proto(context)
|
225
232
|
|
226
233
|
try:
|
234
|
+
|
235
|
+
# Push Message
|
236
|
+
_ = stub.PushMessage(PushMessageRequest(token=token, message=proto_message))
|
237
|
+
|
238
|
+
# Push Context
|
227
239
|
res: PushClientAppOutputsResponse = stub.PushClientAppOutputs(
|
228
|
-
PushClientAppOutputsRequest(
|
229
|
-
token=token, message=proto_message, context=proto_context
|
230
|
-
)
|
240
|
+
PushClientAppOutputsRequest(token=token, context=proto_context)
|
231
241
|
)
|
232
242
|
return res
|
233
243
|
except grpc.RpcError as e:
|
@@ -39,8 +39,12 @@ from flwr.proto.clientappio_pb2 import ( # pylint: disable=E0401
|
|
39
39
|
GetRunIdsWithPendingMessagesResponse,
|
40
40
|
PullClientAppInputsRequest,
|
41
41
|
PullClientAppInputsResponse,
|
42
|
+
PullMessageRequest,
|
43
|
+
PullMessageResponse,
|
42
44
|
PushClientAppOutputsRequest,
|
43
45
|
PushClientAppOutputsResponse,
|
46
|
+
PushMessageRequest,
|
47
|
+
PushMessageResponse,
|
44
48
|
RequestTokenRequest,
|
45
49
|
RequestTokenResponse,
|
46
50
|
)
|
@@ -119,14 +123,12 @@ class ClientAppIoServicer(clientappio_pb2_grpc.ClientAppIoServicer):
|
|
119
123
|
)
|
120
124
|
raise RuntimeError("This line should never be reached.")
|
121
125
|
|
122
|
-
# Retrieve
|
123
|
-
message = state.get_messages(run_ids=[run_id], is_reply=False)[0]
|
126
|
+
# Retrieve context, run and fab for this run
|
124
127
|
context = cast(Context, state.get_context(run_id))
|
125
128
|
run = cast(Run, state.get_run(run_id))
|
126
129
|
fab = Fab(run.fab_hash, ffs.get(run.fab_hash)[0]) # type: ignore
|
127
130
|
|
128
131
|
return PullClientAppInputsResponse(
|
129
|
-
message=message_to_proto(message),
|
130
132
|
context=context_to_proto(context),
|
131
133
|
run=run_to_proto(run),
|
132
134
|
fab=fab_to_proto(fab),
|
@@ -150,8 +152,7 @@ class ClientAppIoServicer(clientappio_pb2_grpc.ClientAppIoServicer):
|
|
150
152
|
)
|
151
153
|
raise RuntimeError("This line should never be reached.")
|
152
154
|
|
153
|
-
# Save the
|
154
|
-
state.store_message(message_from_proto(request.message))
|
155
|
+
# Save the context to the state
|
155
156
|
state.store_context(context_from_proto(request.context))
|
156
157
|
|
157
158
|
# Remove the token to make the run eligible for processing
|
@@ -159,3 +160,45 @@ class ClientAppIoServicer(clientappio_pb2_grpc.ClientAppIoServicer):
|
|
159
160
|
state.delete_token(run_id)
|
160
161
|
|
161
162
|
return PushClientAppOutputsResponse()
|
163
|
+
|
164
|
+
def PullMessage(
|
165
|
+
self, request: PullMessageRequest, context: grpc.ServicerContext
|
166
|
+
) -> PullMessageResponse:
|
167
|
+
"""Pull one Message."""
|
168
|
+
# Initialize state and ffs connection
|
169
|
+
state = self.state_factory.state()
|
170
|
+
|
171
|
+
# Validate the token
|
172
|
+
run_id = state.get_run_id_by_token(request.token)
|
173
|
+
if run_id is None or not state.verify_token(run_id, request.token):
|
174
|
+
context.abort(
|
175
|
+
grpc.StatusCode.PERMISSION_DENIED,
|
176
|
+
"Invalid token.",
|
177
|
+
)
|
178
|
+
raise RuntimeError("This line should never be reached.")
|
179
|
+
|
180
|
+
# Retrieve message, context, run and fab for this run
|
181
|
+
message = state.get_messages(run_ids=[run_id], is_reply=False)[0]
|
182
|
+
|
183
|
+
return PullMessageResponse(message=message_to_proto(message))
|
184
|
+
|
185
|
+
def PushMessage(
|
186
|
+
self, request: PushMessageRequest, context: grpc.ServicerContext
|
187
|
+
) -> PushMessageResponse:
|
188
|
+
"""Push one Message."""
|
189
|
+
# Initialize state connection
|
190
|
+
state = self.state_factory.state()
|
191
|
+
|
192
|
+
# Validate the token
|
193
|
+
run_id = state.get_run_id_by_token(request.token)
|
194
|
+
if run_id is None or not state.verify_token(run_id, request.token):
|
195
|
+
context.abort(
|
196
|
+
grpc.StatusCode.PERMISSION_DENIED,
|
197
|
+
"Invalid token.",
|
198
|
+
)
|
199
|
+
raise RuntimeError("This line should never be reached.")
|
200
|
+
|
201
|
+
# Save the message and context to the state
|
202
|
+
state.store_message(message_from_proto(request.message))
|
203
|
+
|
204
|
+
return PushMessageResponse()
|