flwr 1.18.0__py3-none-any.whl → 1.20.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/app/__init__.py +15 -0
- flwr/app/error.py +68 -0
- flwr/app/metadata.py +223 -0
- flwr/cli/build.py +94 -59
- flwr/cli/log.py +3 -3
- flwr/cli/login/login.py +3 -7
- flwr/cli/ls.py +15 -36
- flwr/cli/new/new.py +12 -4
- flwr/cli/new/templates/app/README.flowertune.md.tpl +2 -0
- flwr/cli/new/templates/app/README.md.tpl +5 -0
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +1 -1
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +2 -3
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +25 -17
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +13 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +21 -2
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +18 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +19 -2
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +18 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +20 -3
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +18 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +18 -1
- flwr/cli/run/run.py +48 -49
- flwr/cli/stop.py +2 -2
- flwr/cli/utils.py +38 -5
- flwr/client/__init__.py +2 -2
- flwr/client/client_app.py +1 -1
- flwr/client/clientapp/__init__.py +0 -7
- flwr/client/grpc_adapter_client/connection.py +15 -8
- flwr/client/grpc_rere_client/connection.py +142 -97
- flwr/client/grpc_rere_client/grpc_adapter.py +34 -6
- flwr/client/message_handler/message_handler.py +1 -1
- flwr/client/mod/comms_mods.py +36 -17
- flwr/client/rest_client/connection.py +176 -103
- flwr/clientapp/__init__.py +15 -0
- flwr/common/__init__.py +2 -2
- flwr/common/auth_plugin/__init__.py +2 -0
- flwr/common/auth_plugin/auth_plugin.py +29 -3
- flwr/common/constant.py +39 -8
- flwr/common/event_log_plugin/event_log_plugin.py +3 -3
- flwr/common/exit/exit_code.py +16 -1
- flwr/common/exit_handlers.py +30 -0
- flwr/common/grpc.py +12 -1
- flwr/common/heartbeat.py +165 -0
- flwr/common/inflatable.py +290 -0
- flwr/common/inflatable_protobuf_utils.py +141 -0
- flwr/common/inflatable_utils.py +508 -0
- flwr/common/message.py +110 -242
- flwr/common/record/__init__.py +2 -1
- flwr/common/record/array.py +402 -0
- flwr/common/record/arraychunk.py +59 -0
- flwr/common/record/arrayrecord.py +103 -225
- flwr/common/record/configrecord.py +59 -4
- flwr/common/record/conversion_utils.py +1 -1
- flwr/common/record/metricrecord.py +55 -4
- flwr/common/record/recorddict.py +69 -1
- flwr/common/recorddict_compat.py +2 -2
- flwr/common/retry_invoker.py +5 -1
- flwr/common/serde.py +59 -211
- flwr/common/serde_utils.py +175 -0
- flwr/common/typing.py +5 -3
- flwr/compat/__init__.py +15 -0
- flwr/compat/client/__init__.py +15 -0
- flwr/{client → compat/client}/app.py +28 -185
- flwr/compat/common/__init__.py +15 -0
- flwr/compat/server/__init__.py +15 -0
- flwr/compat/server/app.py +174 -0
- flwr/compat/simulation/__init__.py +15 -0
- flwr/proto/appio_pb2.py +43 -0
- flwr/proto/appio_pb2.pyi +151 -0
- flwr/proto/appio_pb2_grpc.py +4 -0
- flwr/proto/appio_pb2_grpc.pyi +4 -0
- flwr/proto/clientappio_pb2.py +12 -19
- flwr/proto/clientappio_pb2.pyi +23 -101
- flwr/proto/clientappio_pb2_grpc.py +269 -28
- flwr/proto/clientappio_pb2_grpc.pyi +114 -20
- flwr/proto/fleet_pb2.py +24 -27
- flwr/proto/fleet_pb2.pyi +19 -35
- flwr/proto/fleet_pb2_grpc.py +117 -13
- flwr/proto/fleet_pb2_grpc.pyi +47 -6
- flwr/proto/heartbeat_pb2.py +33 -0
- flwr/proto/heartbeat_pb2.pyi +66 -0
- flwr/proto/heartbeat_pb2_grpc.py +4 -0
- flwr/proto/heartbeat_pb2_grpc.pyi +4 -0
- flwr/proto/message_pb2.py +28 -11
- flwr/proto/message_pb2.pyi +125 -0
- flwr/proto/recorddict_pb2.py +16 -28
- flwr/proto/recorddict_pb2.pyi +46 -64
- flwr/proto/run_pb2.py +24 -32
- flwr/proto/run_pb2.pyi +4 -52
- flwr/proto/serverappio_pb2.py +9 -23
- flwr/proto/serverappio_pb2.pyi +0 -110
- flwr/proto/serverappio_pb2_grpc.py +177 -72
- flwr/proto/serverappio_pb2_grpc.pyi +75 -33
- flwr/proto/simulationio_pb2.py +12 -11
- flwr/proto/simulationio_pb2_grpc.py +35 -0
- flwr/proto/simulationio_pb2_grpc.pyi +14 -0
- flwr/server/__init__.py +1 -1
- flwr/server/app.py +69 -187
- flwr/server/compat/app_utils.py +50 -28
- flwr/server/fleet_event_log_interceptor.py +6 -2
- flwr/server/grid/grpc_grid.py +148 -41
- flwr/server/grid/inmemory_grid.py +5 -4
- flwr/server/serverapp/app.py +45 -17
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +21 -3
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +102 -8
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -5
- flwr/server/superlink/fleet/message_handler/message_handler.py +130 -19
- flwr/server/superlink/fleet/rest_rere/rest_api.py +73 -13
- flwr/server/superlink/fleet/vce/vce_api.py +6 -3
- flwr/server/superlink/linkstate/in_memory_linkstate.py +138 -43
- flwr/server/superlink/linkstate/linkstate.py +53 -20
- flwr/server/superlink/linkstate/sqlite_linkstate.py +149 -55
- flwr/server/superlink/linkstate/utils.py +33 -29
- flwr/server/superlink/serverappio/serverappio_grpc.py +4 -1
- flwr/server/superlink/serverappio/serverappio_servicer.py +230 -84
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
- flwr/server/superlink/simulation/simulationio_servicer.py +26 -2
- flwr/server/superlink/utils.py +9 -2
- flwr/server/utils/validator.py +2 -2
- flwr/serverapp/__init__.py +15 -0
- flwr/simulation/app.py +25 -0
- flwr/simulation/run_simulation.py +17 -0
- flwr/supercore/__init__.py +15 -0
- flwr/{server/superlink → supercore}/ffs/__init__.py +2 -0
- flwr/{server/superlink → supercore}/ffs/disk_ffs.py +1 -1
- flwr/supercore/grpc_health/__init__.py +22 -0
- flwr/supercore/grpc_health/simple_health_servicer.py +38 -0
- flwr/supercore/license_plugin/__init__.py +22 -0
- flwr/supercore/license_plugin/license_plugin.py +26 -0
- flwr/supercore/object_store/__init__.py +24 -0
- flwr/supercore/object_store/in_memory_object_store.py +229 -0
- flwr/supercore/object_store/object_store.py +170 -0
- flwr/supercore/object_store/object_store_factory.py +44 -0
- flwr/supercore/object_store/utils.py +43 -0
- flwr/supercore/scheduler/__init__.py +22 -0
- flwr/supercore/scheduler/plugin.py +71 -0
- flwr/{client/nodestate/nodestate.py → supercore/utils.py} +14 -13
- flwr/superexec/deployment.py +7 -4
- flwr/superexec/exec_event_log_interceptor.py +8 -4
- flwr/superexec/exec_grpc.py +25 -5
- flwr/superexec/exec_license_interceptor.py +82 -0
- flwr/superexec/exec_servicer.py +135 -24
- flwr/superexec/exec_user_auth_interceptor.py +45 -8
- flwr/superexec/executor.py +5 -1
- flwr/superexec/simulation.py +8 -3
- flwr/superlink/__init__.py +15 -0
- flwr/{client/supernode → supernode}/__init__.py +0 -7
- flwr/supernode/cli/__init__.py +24 -0
- flwr/{client/supernode/app.py → supernode/cli/flower_supernode.py} +3 -19
- flwr/supernode/cli/flwr_clientapp.py +88 -0
- flwr/supernode/nodestate/in_memory_nodestate.py +199 -0
- flwr/supernode/nodestate/nodestate.py +227 -0
- flwr/supernode/runtime/__init__.py +15 -0
- flwr/{client/clientapp/app.py → supernode/runtime/run_clientapp.py} +135 -89
- flwr/supernode/scheduler/__init__.py +22 -0
- flwr/supernode/scheduler/simple_clientapp_scheduler_plugin.py +49 -0
- flwr/supernode/servicer/__init__.py +15 -0
- flwr/supernode/servicer/clientappio/__init__.py +22 -0
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +303 -0
- flwr/supernode/start_client_internal.py +589 -0
- {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/METADATA +6 -4
- {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/RECORD +171 -123
- {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/WHEEL +1 -1
- {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/entry_points.txt +2 -2
- flwr/client/clientapp/clientappio_servicer.py +0 -244
- flwr/client/heartbeat.py +0 -74
- flwr/client/nodestate/in_memory_nodestate.py +0 -38
- /flwr/{client → compat/client}/grpc_client/__init__.py +0 -0
- /flwr/{client → compat/client}/grpc_client/connection.py +0 -0
- /flwr/{server/superlink → supercore}/ffs/ffs.py +0 -0
- /flwr/{server/superlink → supercore}/ffs/ffs_factory.py +0 -0
- /flwr/{client → supernode}/nodestate/__init__.py +0 -0
- /flwr/{client → supernode}/nodestate/nodestate_factory.py +0 -0
flwr/server/grid/grpc_grid.py
CHANGED
|
@@ -22,26 +22,51 @@ from typing import Optional, cast
|
|
|
22
22
|
|
|
23
23
|
import grpc
|
|
24
24
|
|
|
25
|
-
from flwr.
|
|
25
|
+
from flwr.app.error import Error
|
|
26
|
+
from flwr.common import Message, Metadata, RecordDict, now
|
|
26
27
|
from flwr.common.constant import (
|
|
27
28
|
SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS,
|
|
28
29
|
SUPERLINK_NODE_ID,
|
|
30
|
+
ErrorCode,
|
|
31
|
+
MessageType,
|
|
29
32
|
)
|
|
30
33
|
from flwr.common.grpc import create_channel, on_channel_state_change
|
|
34
|
+
from flwr.common.inflatable import (
|
|
35
|
+
InflatableObject,
|
|
36
|
+
get_all_nested_objects,
|
|
37
|
+
get_object_tree,
|
|
38
|
+
iterate_object_tree,
|
|
39
|
+
no_object_id_recompute,
|
|
40
|
+
)
|
|
41
|
+
from flwr.common.inflatable_protobuf_utils import (
|
|
42
|
+
make_pull_object_fn_protobuf,
|
|
43
|
+
make_push_object_fn_protobuf,
|
|
44
|
+
)
|
|
45
|
+
from flwr.common.inflatable_utils import (
|
|
46
|
+
ObjectUnavailableError,
|
|
47
|
+
inflate_object_from_contents,
|
|
48
|
+
pull_objects,
|
|
49
|
+
push_objects,
|
|
50
|
+
)
|
|
31
51
|
from flwr.common.logger import log, warn_deprecated_feature
|
|
52
|
+
from flwr.common.message import make_message, remove_content_from_message
|
|
32
53
|
from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub
|
|
33
|
-
from flwr.common.serde import
|
|
54
|
+
from flwr.common.serde import message_to_proto, run_from_proto
|
|
34
55
|
from flwr.common.typing import Run
|
|
35
|
-
from flwr.proto.
|
|
56
|
+
from flwr.proto.appio_pb2 import ( # pylint: disable=E0611
|
|
57
|
+
PullAppMessagesRequest,
|
|
58
|
+
PullAppMessagesResponse,
|
|
59
|
+
PushAppMessagesRequest,
|
|
60
|
+
PushAppMessagesResponse,
|
|
61
|
+
)
|
|
62
|
+
from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
|
63
|
+
ConfirmMessageReceivedRequest,
|
|
64
|
+
)
|
|
36
65
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
37
66
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
38
67
|
from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
|
|
39
68
|
GetNodesRequest,
|
|
40
69
|
GetNodesResponse,
|
|
41
|
-
PullResMessagesRequest,
|
|
42
|
-
PullResMessagesResponse,
|
|
43
|
-
PushInsMessagesRequest,
|
|
44
|
-
PushInsMessagesResponse,
|
|
45
70
|
)
|
|
46
71
|
from flwr.proto.serverappio_pb2_grpc import ServerAppIoStub # pylint: disable=E0611
|
|
47
72
|
|
|
@@ -163,7 +188,7 @@ class GrpcGrid(Grid):
|
|
|
163
188
|
def _check_message(self, message: Message) -> None:
|
|
164
189
|
# Check if the message is valid
|
|
165
190
|
if not (
|
|
166
|
-
message.metadata.message_id
|
|
191
|
+
message.metadata.message_id != ""
|
|
167
192
|
and message.metadata.reply_to_message_id == ""
|
|
168
193
|
and message.metadata.ttl > 0
|
|
169
194
|
):
|
|
@@ -198,6 +223,39 @@ class GrpcGrid(Grid):
|
|
|
198
223
|
)
|
|
199
224
|
return [node.node_id for node in res.nodes]
|
|
200
225
|
|
|
226
|
+
def _try_push_messages(self, run_id: int, messages: Iterable[Message]) -> list[str]:
|
|
227
|
+
"""Push all messages and its associated objects."""
|
|
228
|
+
# Prepare all Messages to be sent in a single request
|
|
229
|
+
proto_messages = []
|
|
230
|
+
object_trees = []
|
|
231
|
+
all_objects: dict[str, InflatableObject] = {}
|
|
232
|
+
for msg in messages:
|
|
233
|
+
proto_messages.append(message_to_proto(remove_content_from_message(msg)))
|
|
234
|
+
all_objects.update(get_all_nested_objects(msg))
|
|
235
|
+
object_trees.append(get_object_tree(msg))
|
|
236
|
+
del msg
|
|
237
|
+
|
|
238
|
+
# Call GrpcServerAppIoStub method
|
|
239
|
+
res: PushAppMessagesResponse = self._stub.PushMessages(
|
|
240
|
+
PushAppMessagesRequest(
|
|
241
|
+
messages_list=proto_messages,
|
|
242
|
+
run_id=run_id,
|
|
243
|
+
message_object_trees=object_trees,
|
|
244
|
+
)
|
|
245
|
+
)
|
|
246
|
+
|
|
247
|
+
# Push objects
|
|
248
|
+
push_objects(
|
|
249
|
+
all_objects,
|
|
250
|
+
push_object_fn=make_push_object_fn_protobuf(
|
|
251
|
+
push_object_protobuf=self._stub.PushObject,
|
|
252
|
+
node=self.node,
|
|
253
|
+
run_id=run_id,
|
|
254
|
+
),
|
|
255
|
+
object_ids_to_push=set(res.objects_to_push),
|
|
256
|
+
)
|
|
257
|
+
return cast(list[str], res.message_ids)
|
|
258
|
+
|
|
201
259
|
def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
|
|
202
260
|
"""Push messages to specified node IDs.
|
|
203
261
|
|
|
@@ -206,57 +264,106 @@ class GrpcGrid(Grid):
|
|
|
206
264
|
"""
|
|
207
265
|
# Construct Messages
|
|
208
266
|
run_id = cast(Run, self._run).run_id
|
|
209
|
-
|
|
210
|
-
for msg in messages:
|
|
211
|
-
# Populate metadata
|
|
212
|
-
msg.metadata.__dict__["_run_id"] = run_id
|
|
213
|
-
msg.metadata.__dict__["_src_node_id"] = self.node.node_id
|
|
214
|
-
# Check message
|
|
215
|
-
self._check_message(msg)
|
|
216
|
-
# Convert to proto
|
|
217
|
-
msg_proto = message_to_proto(msg)
|
|
218
|
-
# Add to list
|
|
219
|
-
message_proto_list.append(msg_proto)
|
|
220
|
-
|
|
267
|
+
message_ids: list[str] = []
|
|
221
268
|
try:
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
"passed to `push_messages`). This could be due to a malformed "
|
|
234
|
-
"message.",
|
|
235
|
-
)
|
|
236
|
-
return list(res.message_ids)
|
|
269
|
+
with no_object_id_recompute():
|
|
270
|
+
for msg in messages:
|
|
271
|
+
# Populate metadata
|
|
272
|
+
msg.metadata.__dict__["_run_id"] = run_id
|
|
273
|
+
msg.metadata.__dict__["_src_node_id"] = self.node.node_id
|
|
274
|
+
msg.metadata.__dict__["_message_id"] = msg.object_id
|
|
275
|
+
# Check message
|
|
276
|
+
self._check_message(msg)
|
|
277
|
+
# Try pushing messages and their objects
|
|
278
|
+
message_ids = self._try_push_messages(run_id, messages)
|
|
279
|
+
|
|
237
280
|
except grpc.RpcError as e:
|
|
238
281
|
if e.code() == grpc.StatusCode.RESOURCE_EXHAUSTED: # pylint: disable=E1101
|
|
239
282
|
log(ERROR, ERROR_MESSAGE_PUSH_MESSAGES_RESOURCE_EXHAUSTED)
|
|
240
283
|
return []
|
|
241
284
|
raise
|
|
242
285
|
|
|
286
|
+
if None in message_ids:
|
|
287
|
+
log(
|
|
288
|
+
WARNING,
|
|
289
|
+
"Not all messages could be pushed to the SuperLink. The returned "
|
|
290
|
+
"list has `None` for those messages (the order is preserved as "
|
|
291
|
+
"passed to `push_messages`). This could be due to a malformed "
|
|
292
|
+
"message.",
|
|
293
|
+
)
|
|
294
|
+
|
|
295
|
+
return message_ids
|
|
296
|
+
|
|
243
297
|
def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
|
|
244
298
|
"""Pull messages based on message IDs.
|
|
245
299
|
|
|
246
300
|
This method is used to collect messages from the SuperLink that correspond to a
|
|
247
301
|
set of given message IDs.
|
|
248
302
|
"""
|
|
303
|
+
run_id = cast(Run, self._run).run_id
|
|
249
304
|
try:
|
|
250
305
|
# Pull Messages
|
|
251
|
-
res:
|
|
252
|
-
|
|
306
|
+
res: PullAppMessagesResponse = self._stub.PullMessages(
|
|
307
|
+
PullAppMessagesRequest(
|
|
253
308
|
message_ids=message_ids,
|
|
254
|
-
run_id=
|
|
309
|
+
run_id=run_id,
|
|
255
310
|
)
|
|
256
311
|
)
|
|
257
|
-
#
|
|
258
|
-
|
|
259
|
-
|
|
312
|
+
# Pull Messages from store
|
|
313
|
+
inflated_msgs: list[Message] = []
|
|
314
|
+
for msg_proto, msg_tree in zip(res.messages_list, res.message_object_trees):
|
|
315
|
+
msg_id = msg_proto.metadata.message_id
|
|
316
|
+
try:
|
|
317
|
+
all_object_contents = pull_objects(
|
|
318
|
+
object_ids=[
|
|
319
|
+
tree.object_id for tree in iterate_object_tree(msg_tree)
|
|
320
|
+
],
|
|
321
|
+
pull_object_fn=make_pull_object_fn_protobuf(
|
|
322
|
+
pull_object_protobuf=self._stub.PullObject,
|
|
323
|
+
node=self.node,
|
|
324
|
+
run_id=run_id,
|
|
325
|
+
),
|
|
326
|
+
)
|
|
327
|
+
except ObjectUnavailableError as e:
|
|
328
|
+
# An ObjectUnavailableError indicates that the object is not yet
|
|
329
|
+
# available. If this point has been reached, it means that the
|
|
330
|
+
# Grid has tried to pull the object for the maximum number of times
|
|
331
|
+
# or for the maximum time allowed, so we return an inflated message
|
|
332
|
+
# with an error
|
|
333
|
+
inflated_msgs.append(
|
|
334
|
+
make_message(
|
|
335
|
+
metadata=Metadata(
|
|
336
|
+
run_id=run_id,
|
|
337
|
+
message_id="",
|
|
338
|
+
src_node_id=self.node.node_id,
|
|
339
|
+
dst_node_id=self.node.node_id,
|
|
340
|
+
message_type=MessageType.SYSTEM,
|
|
341
|
+
group_id="",
|
|
342
|
+
ttl=0,
|
|
343
|
+
reply_to_message_id=msg_proto.metadata.reply_to_message_id,
|
|
344
|
+
created_at=now().timestamp(),
|
|
345
|
+
),
|
|
346
|
+
error=Error(
|
|
347
|
+
code=ErrorCode.MESSAGE_UNAVAILABLE, reason=(str(e))
|
|
348
|
+
),
|
|
349
|
+
)
|
|
350
|
+
)
|
|
351
|
+
continue
|
|
352
|
+
|
|
353
|
+
# Confirm that the message has been received
|
|
354
|
+
self._stub.ConfirmMessageReceived(
|
|
355
|
+
ConfirmMessageReceivedRequest(
|
|
356
|
+
node=self.node, run_id=run_id, message_object_id=msg_id
|
|
357
|
+
)
|
|
358
|
+
)
|
|
359
|
+
message = cast(
|
|
360
|
+
Message, inflate_object_from_contents(msg_id, all_object_contents)
|
|
361
|
+
)
|
|
362
|
+
message.metadata.__dict__["_message_id"] = msg_id
|
|
363
|
+
inflated_msgs.append(message)
|
|
364
|
+
|
|
365
|
+
return inflated_msgs
|
|
366
|
+
|
|
260
367
|
except grpc.RpcError as e:
|
|
261
368
|
if e.code() == grpc.StatusCode.RESOURCE_EXHAUSTED: # pylint: disable=E1101
|
|
262
369
|
log(ERROR, ERROR_MESSAGE_PULL_MESSAGES_RESOURCE_EXHAUSTED)
|
|
@@ -18,7 +18,7 @@
|
|
|
18
18
|
import time
|
|
19
19
|
from collections.abc import Iterable
|
|
20
20
|
from typing import Optional, cast
|
|
21
|
-
from uuid import
|
|
21
|
+
from uuid import uuid4
|
|
22
22
|
|
|
23
23
|
from flwr.common import Message, RecordDict
|
|
24
24
|
from flwr.common.constant import SUPERLINK_NODE_ID
|
|
@@ -56,7 +56,7 @@ class InMemoryGrid(Grid):
|
|
|
56
56
|
def _check_message(self, message: Message) -> None:
|
|
57
57
|
# Check if the message is valid
|
|
58
58
|
if not (
|
|
59
|
-
message.metadata.message_id
|
|
59
|
+
message.metadata.message_id != ""
|
|
60
60
|
and message.metadata.reply_to_message_id == ""
|
|
61
61
|
and message.metadata.ttl > 0
|
|
62
62
|
and message.metadata.delivered_at == ""
|
|
@@ -111,6 +111,7 @@ class InMemoryGrid(Grid):
|
|
|
111
111
|
# Populate metadata
|
|
112
112
|
msg.metadata.__dict__["_run_id"] = cast(Run, self._run).run_id
|
|
113
113
|
msg.metadata.__dict__["_src_node_id"] = self.node.node_id
|
|
114
|
+
msg.metadata.__dict__["_message_id"] = str(uuid4())
|
|
114
115
|
# Check message
|
|
115
116
|
self._check_message(msg)
|
|
116
117
|
# Store in state
|
|
@@ -126,12 +127,12 @@ class InMemoryGrid(Grid):
|
|
|
126
127
|
This method is used to collect messages from the SuperLink that correspond to a
|
|
127
128
|
set of given message IDs.
|
|
128
129
|
"""
|
|
129
|
-
msg_ids =
|
|
130
|
+
msg_ids = set(message_ids)
|
|
130
131
|
# Pull Messages
|
|
131
132
|
message_res_list = self.state.get_message_res(message_ids=msg_ids)
|
|
132
133
|
# Get IDs of Messages these replies are for
|
|
133
134
|
message_ins_ids_to_delete = {
|
|
134
|
-
|
|
135
|
+
msg_res.metadata.reply_to_message_id for msg_res in message_res_list
|
|
135
136
|
}
|
|
136
137
|
# Delete
|
|
137
138
|
self.state.delete_messages(message_ins_ids=message_ins_ids_to_delete)
|
flwr/server/serverapp/app.py
CHANGED
|
@@ -16,6 +16,7 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import argparse
|
|
19
|
+
import gc
|
|
19
20
|
from logging import DEBUG, ERROR, INFO
|
|
20
21
|
from pathlib import Path
|
|
21
22
|
from queue import Queue
|
|
@@ -38,6 +39,7 @@ from flwr.common.constant import (
|
|
|
38
39
|
SubStatus,
|
|
39
40
|
)
|
|
40
41
|
from flwr.common.exit import ExitCode, flwr_exit
|
|
42
|
+
from flwr.common.heartbeat import HeartbeatSender, get_grpc_app_heartbeat_fn
|
|
41
43
|
from flwr.common.logger import (
|
|
42
44
|
log,
|
|
43
45
|
mirror_output_to_queue,
|
|
@@ -54,12 +56,12 @@ from flwr.common.serde import (
|
|
|
54
56
|
)
|
|
55
57
|
from flwr.common.telemetry import EventType, event
|
|
56
58
|
from flwr.common.typing import RunNotRunningException, RunStatus
|
|
57
|
-
from flwr.proto.
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
PushServerAppOutputsRequest,
|
|
59
|
+
from flwr.proto.appio_pb2 import ( # pylint: disable=E0611
|
|
60
|
+
PullAppInputsRequest,
|
|
61
|
+
PullAppInputsResponse,
|
|
62
|
+
PushAppOutputsRequest,
|
|
62
63
|
)
|
|
64
|
+
from flwr.proto.run_pb2 import UpdateRunStatusRequest # pylint: disable=E0611
|
|
63
65
|
from flwr.server.grid.grpc_grid import GrpcGrid
|
|
64
66
|
from flwr.server.run_serverapp import run as run_
|
|
65
67
|
|
|
@@ -106,24 +108,28 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
|
106
108
|
certificates: Optional[bytes] = None,
|
|
107
109
|
) -> None:
|
|
108
110
|
"""Run Flower ServerApp process."""
|
|
109
|
-
grid = GrpcGrid(
|
|
110
|
-
serverappio_service_address=serverappio_api_address,
|
|
111
|
-
root_certificates=certificates,
|
|
112
|
-
)
|
|
113
|
-
|
|
114
111
|
# Resolve directory where FABs are installed
|
|
115
112
|
flwr_dir_ = get_flwr_dir(flwr_dir)
|
|
116
113
|
log_uploader = None
|
|
117
114
|
success = True
|
|
118
115
|
hash_run_id = None
|
|
119
116
|
run_status = None
|
|
117
|
+
heartbeat_sender = None
|
|
118
|
+
grid = None
|
|
119
|
+
context = None
|
|
120
120
|
while True:
|
|
121
121
|
|
|
122
122
|
try:
|
|
123
|
+
# Initialize the GrpcGrid
|
|
124
|
+
grid = GrpcGrid(
|
|
125
|
+
serverappio_service_address=serverappio_api_address,
|
|
126
|
+
root_certificates=certificates,
|
|
127
|
+
)
|
|
128
|
+
|
|
123
129
|
# Pull ServerAppInputs from LinkState
|
|
124
|
-
req =
|
|
130
|
+
req = PullAppInputsRequest()
|
|
125
131
|
log(DEBUG, "[flwr-serverapp] Pull ServerAppInputs")
|
|
126
|
-
res:
|
|
132
|
+
res: PullAppInputsResponse = grid._stub.PullAppInputs(req)
|
|
127
133
|
if not res.HasField("run"):
|
|
128
134
|
sleep(3)
|
|
129
135
|
run_status = None
|
|
@@ -182,6 +188,16 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
|
182
188
|
event_details={"run-id-hash": hash_run_id},
|
|
183
189
|
)
|
|
184
190
|
|
|
191
|
+
# Set up heartbeat sender
|
|
192
|
+
heartbeat_fn = get_grpc_app_heartbeat_fn(
|
|
193
|
+
grid._stub,
|
|
194
|
+
run.run_id,
|
|
195
|
+
failure_message="Heartbeat failed unexpectedly. The SuperLink could "
|
|
196
|
+
"not find the provided run ID, or the run status is invalid.",
|
|
197
|
+
)
|
|
198
|
+
heartbeat_sender = HeartbeatSender(heartbeat_fn)
|
|
199
|
+
heartbeat_sender.start()
|
|
200
|
+
|
|
185
201
|
# Load and run the ServerApp with the Grid
|
|
186
202
|
updated_context = run_(
|
|
187
203
|
grid=grid,
|
|
@@ -193,10 +209,8 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
|
193
209
|
# Send resulting context
|
|
194
210
|
context_proto = context_to_proto(updated_context)
|
|
195
211
|
log(DEBUG, "[flwr-serverapp] Will push ServerAppOutputs")
|
|
196
|
-
out_req =
|
|
197
|
-
|
|
198
|
-
)
|
|
199
|
-
_ = grid._stub.PushServerAppOutputs(out_req)
|
|
212
|
+
out_req = PushAppOutputsRequest(run_id=run.run_id, context=context_proto)
|
|
213
|
+
_ = grid._stub.PushAppOutputs(out_req)
|
|
200
214
|
|
|
201
215
|
run_status = RunStatus(Status.FINISHED, SubStatus.COMPLETED, "")
|
|
202
216
|
except RunNotRunningException:
|
|
@@ -213,19 +227,33 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
|
213
227
|
success = False
|
|
214
228
|
|
|
215
229
|
finally:
|
|
230
|
+
# Stop heartbeat sender
|
|
231
|
+
if heartbeat_sender:
|
|
232
|
+
heartbeat_sender.stop()
|
|
233
|
+
heartbeat_sender = None
|
|
234
|
+
|
|
216
235
|
# Stop log uploader for this run and upload final logs
|
|
217
236
|
if log_uploader:
|
|
218
237
|
stop_log_uploader(log_queue, log_uploader)
|
|
219
238
|
log_uploader = None
|
|
220
239
|
|
|
221
240
|
# Update run status
|
|
222
|
-
if run_status:
|
|
241
|
+
if run_status and grid:
|
|
223
242
|
run_status_proto = run_status_to_proto(run_status)
|
|
224
243
|
grid._stub.UpdateRunStatus(
|
|
225
244
|
UpdateRunStatusRequest(
|
|
226
245
|
run_id=run.run_id, run_status=run_status_proto
|
|
227
246
|
)
|
|
228
247
|
)
|
|
248
|
+
|
|
249
|
+
# Close the Grpc connection
|
|
250
|
+
if grid:
|
|
251
|
+
grid.close()
|
|
252
|
+
|
|
253
|
+
# Clean up the Context
|
|
254
|
+
context = None
|
|
255
|
+
gc.collect()
|
|
256
|
+
|
|
229
257
|
event(
|
|
230
258
|
EventType.FLWR_SERVERAPP_RUN_LEAVE,
|
|
231
259
|
event_details={"run-id-hash": hash_run_id, "success": success},
|
|
@@ -35,11 +35,16 @@ from flwr.proto.fab_pb2 import GetFabRequest # pylint: disable=E0611
|
|
|
35
35
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
36
36
|
CreateNodeRequest,
|
|
37
37
|
DeleteNodeRequest,
|
|
38
|
-
PingRequest,
|
|
39
38
|
PullMessagesRequest,
|
|
40
39
|
PushMessagesRequest,
|
|
41
40
|
)
|
|
42
41
|
from flwr.proto.grpcadapter_pb2 import MessageContainer # pylint: disable=E0611
|
|
42
|
+
from flwr.proto.heartbeat_pb2 import SendNodeHeartbeatRequest # pylint: disable=E0611
|
|
43
|
+
from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
|
44
|
+
ConfirmMessageReceivedRequest,
|
|
45
|
+
PullObjectRequest,
|
|
46
|
+
PushObjectRequest,
|
|
47
|
+
)
|
|
43
48
|
from flwr.proto.run_pb2 import GetRunRequest # pylint: disable=E0611
|
|
44
49
|
|
|
45
50
|
from ..grpc_rere.fleet_servicer import FleetServicer
|
|
@@ -81,8 +86,10 @@ class GrpcAdapterServicer(grpcadapter_pb2_grpc.GrpcAdapterServicer, FleetService
|
|
|
81
86
|
return _handle(request, context, CreateNodeRequest, self.CreateNode)
|
|
82
87
|
if request.grpc_message_name == DeleteNodeRequest.__qualname__:
|
|
83
88
|
return _handle(request, context, DeleteNodeRequest, self.DeleteNode)
|
|
84
|
-
if request.grpc_message_name ==
|
|
85
|
-
return _handle(
|
|
89
|
+
if request.grpc_message_name == SendNodeHeartbeatRequest.__qualname__:
|
|
90
|
+
return _handle(
|
|
91
|
+
request, context, SendNodeHeartbeatRequest, self.SendNodeHeartbeat
|
|
92
|
+
)
|
|
86
93
|
if request.grpc_message_name == GetRunRequest.__qualname__:
|
|
87
94
|
return _handle(request, context, GetRunRequest, self.GetRun)
|
|
88
95
|
if request.grpc_message_name == GetFabRequest.__qualname__:
|
|
@@ -91,4 +98,15 @@ class GrpcAdapterServicer(grpcadapter_pb2_grpc.GrpcAdapterServicer, FleetService
|
|
|
91
98
|
return _handle(request, context, PullMessagesRequest, self.PullMessages)
|
|
92
99
|
if request.grpc_message_name == PushMessagesRequest.__qualname__:
|
|
93
100
|
return _handle(request, context, PushMessagesRequest, self.PushMessages)
|
|
101
|
+
if request.grpc_message_name == PushObjectRequest.__qualname__:
|
|
102
|
+
return _handle(request, context, PushObjectRequest, self.PushObject)
|
|
103
|
+
if request.grpc_message_name == PullObjectRequest.__qualname__:
|
|
104
|
+
return _handle(request, context, PullObjectRequest, self.PullObject)
|
|
105
|
+
if request.grpc_message_name == ConfirmMessageReceivedRequest.__qualname__:
|
|
106
|
+
return _handle(
|
|
107
|
+
request,
|
|
108
|
+
context,
|
|
109
|
+
ConfirmMessageReceivedRequest,
|
|
110
|
+
self.ConfirmMessageReceived,
|
|
111
|
+
)
|
|
94
112
|
raise ValueError(f"Invalid grpc_message_name: {request.grpc_message_name}")
|
|
@@ -20,6 +20,7 @@ from logging import DEBUG, INFO
|
|
|
20
20
|
import grpc
|
|
21
21
|
from google.protobuf.json_format import MessageToDict
|
|
22
22
|
|
|
23
|
+
from flwr.common.inflatable import UnexpectedObjectContentError
|
|
23
24
|
from flwr.common.logger import log
|
|
24
25
|
from flwr.common.typing import InvalidRunStatusException
|
|
25
26
|
from flwr.proto import fleet_pb2_grpc # pylint: disable=E0611
|
|
@@ -29,34 +30,53 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
29
30
|
CreateNodeResponse,
|
|
30
31
|
DeleteNodeRequest,
|
|
31
32
|
DeleteNodeResponse,
|
|
32
|
-
PingRequest,
|
|
33
|
-
PingResponse,
|
|
34
33
|
PullMessagesRequest,
|
|
35
34
|
PullMessagesResponse,
|
|
36
35
|
PushMessagesRequest,
|
|
37
36
|
PushMessagesResponse,
|
|
38
37
|
)
|
|
38
|
+
from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
|
|
39
|
+
SendNodeHeartbeatRequest,
|
|
40
|
+
SendNodeHeartbeatResponse,
|
|
41
|
+
)
|
|
42
|
+
from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
|
43
|
+
ConfirmMessageReceivedRequest,
|
|
44
|
+
ConfirmMessageReceivedResponse,
|
|
45
|
+
PullObjectRequest,
|
|
46
|
+
PullObjectResponse,
|
|
47
|
+
PushObjectRequest,
|
|
48
|
+
PushObjectResponse,
|
|
49
|
+
)
|
|
39
50
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
40
|
-
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
41
51
|
from flwr.server.superlink.fleet.message_handler import message_handler
|
|
42
52
|
from flwr.server.superlink.linkstate import LinkStateFactory
|
|
43
53
|
from flwr.server.superlink.utils import abort_grpc_context
|
|
54
|
+
from flwr.supercore.ffs import FfsFactory
|
|
55
|
+
from flwr.supercore.object_store import ObjectStoreFactory
|
|
44
56
|
|
|
45
57
|
|
|
46
58
|
class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
47
59
|
"""Fleet API servicer."""
|
|
48
60
|
|
|
49
61
|
def __init__(
|
|
50
|
-
self,
|
|
62
|
+
self,
|
|
63
|
+
state_factory: LinkStateFactory,
|
|
64
|
+
ffs_factory: FfsFactory,
|
|
65
|
+
objectstore_factory: ObjectStoreFactory,
|
|
51
66
|
) -> None:
|
|
52
67
|
self.state_factory = state_factory
|
|
53
68
|
self.ffs_factory = ffs_factory
|
|
69
|
+
self.objectstore_factory = objectstore_factory
|
|
54
70
|
|
|
55
71
|
def CreateNode(
|
|
56
72
|
self, request: CreateNodeRequest, context: grpc.ServicerContext
|
|
57
73
|
) -> CreateNodeResponse:
|
|
58
74
|
"""."""
|
|
59
|
-
log(
|
|
75
|
+
log(
|
|
76
|
+
INFO,
|
|
77
|
+
"[Fleet.CreateNode] Request heartbeat_interval=%s",
|
|
78
|
+
request.heartbeat_interval,
|
|
79
|
+
)
|
|
60
80
|
log(DEBUG, "[Fleet.CreateNode] Request: %s", MessageToDict(request))
|
|
61
81
|
response = message_handler.create_node(
|
|
62
82
|
request=request,
|
|
@@ -77,10 +97,12 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
|
77
97
|
state=self.state_factory.state(),
|
|
78
98
|
)
|
|
79
99
|
|
|
80
|
-
def
|
|
100
|
+
def SendNodeHeartbeat(
|
|
101
|
+
self, request: SendNodeHeartbeatRequest, context: grpc.ServicerContext
|
|
102
|
+
) -> SendNodeHeartbeatResponse:
|
|
81
103
|
"""."""
|
|
82
|
-
log(DEBUG, "[Fleet.
|
|
83
|
-
return message_handler.
|
|
104
|
+
log(DEBUG, "[Fleet.SendNodeHeartbeat] Request: %s", MessageToDict(request))
|
|
105
|
+
return message_handler.send_node_heartbeat(
|
|
84
106
|
request=request,
|
|
85
107
|
state=self.state_factory.state(),
|
|
86
108
|
)
|
|
@@ -94,6 +116,7 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
|
94
116
|
return message_handler.pull_messages(
|
|
95
117
|
request=request,
|
|
96
118
|
state=self.state_factory.state(),
|
|
119
|
+
store=self.objectstore_factory.store(),
|
|
97
120
|
)
|
|
98
121
|
|
|
99
122
|
def PushMessages(
|
|
@@ -113,6 +136,7 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
|
113
136
|
res = message_handler.push_messages(
|
|
114
137
|
request=request,
|
|
115
138
|
state=self.state_factory.state(),
|
|
139
|
+
store=self.objectstore_factory.store(),
|
|
116
140
|
)
|
|
117
141
|
except InvalidRunStatusException as e:
|
|
118
142
|
abort_grpc_context(e.message, context)
|
|
@@ -129,6 +153,7 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
|
129
153
|
res = message_handler.get_run(
|
|
130
154
|
request=request,
|
|
131
155
|
state=self.state_factory.state(),
|
|
156
|
+
store=self.objectstore_factory.store(),
|
|
132
157
|
)
|
|
133
158
|
except InvalidRunStatusException as e:
|
|
134
159
|
abort_grpc_context(e.message, context)
|
|
@@ -145,6 +170,75 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
|
145
170
|
request=request,
|
|
146
171
|
ffs=self.ffs_factory.ffs(),
|
|
147
172
|
state=self.state_factory.state(),
|
|
173
|
+
store=self.objectstore_factory.store(),
|
|
174
|
+
)
|
|
175
|
+
except InvalidRunStatusException as e:
|
|
176
|
+
abort_grpc_context(e.message, context)
|
|
177
|
+
|
|
178
|
+
return res
|
|
179
|
+
|
|
180
|
+
def PushObject(
|
|
181
|
+
self, request: PushObjectRequest, context: grpc.ServicerContext
|
|
182
|
+
) -> PushObjectResponse:
|
|
183
|
+
"""Push an object to the ObjectStore."""
|
|
184
|
+
log(
|
|
185
|
+
DEBUG,
|
|
186
|
+
"[ServerAppIoServicer.PushObject] Push Object with object_id=%s",
|
|
187
|
+
request.object_id,
|
|
188
|
+
)
|
|
189
|
+
|
|
190
|
+
try:
|
|
191
|
+
# Insert in Store
|
|
192
|
+
res = message_handler.push_object(
|
|
193
|
+
request=request,
|
|
194
|
+
state=self.state_factory.state(),
|
|
195
|
+
store=self.objectstore_factory.store(),
|
|
196
|
+
)
|
|
197
|
+
except InvalidRunStatusException as e:
|
|
198
|
+
abort_grpc_context(e.message, context)
|
|
199
|
+
except UnexpectedObjectContentError as e:
|
|
200
|
+
# Object content is not valid
|
|
201
|
+
context.abort(grpc.StatusCode.FAILED_PRECONDITION, str(e))
|
|
202
|
+
|
|
203
|
+
return res
|
|
204
|
+
|
|
205
|
+
def PullObject(
|
|
206
|
+
self, request: PullObjectRequest, context: grpc.ServicerContext
|
|
207
|
+
) -> PullObjectResponse:
|
|
208
|
+
"""Pull an object from the ObjectStore."""
|
|
209
|
+
log(
|
|
210
|
+
DEBUG,
|
|
211
|
+
"[ServerAppIoServicer.PullObject] Pull Object with object_id=%s",
|
|
212
|
+
request.object_id,
|
|
213
|
+
)
|
|
214
|
+
|
|
215
|
+
try:
|
|
216
|
+
# Fetch from store
|
|
217
|
+
res = message_handler.pull_object(
|
|
218
|
+
request=request,
|
|
219
|
+
state=self.state_factory.state(),
|
|
220
|
+
store=self.objectstore_factory.store(),
|
|
221
|
+
)
|
|
222
|
+
except InvalidRunStatusException as e:
|
|
223
|
+
abort_grpc_context(e.message, context)
|
|
224
|
+
|
|
225
|
+
return res
|
|
226
|
+
|
|
227
|
+
def ConfirmMessageReceived(
|
|
228
|
+
self, request: ConfirmMessageReceivedRequest, context: grpc.ServicerContext
|
|
229
|
+
) -> ConfirmMessageReceivedResponse:
|
|
230
|
+
"""Confirm message received."""
|
|
231
|
+
log(
|
|
232
|
+
DEBUG,
|
|
233
|
+
"[Fleet.ConfirmMessageReceived] Message with ID '%s' has been received",
|
|
234
|
+
request.message_object_id,
|
|
235
|
+
)
|
|
236
|
+
|
|
237
|
+
try:
|
|
238
|
+
res = message_handler.confirm_message_received(
|
|
239
|
+
request=request,
|
|
240
|
+
state=self.state_factory.state(),
|
|
241
|
+
store=self.objectstore_factory.store(),
|
|
148
242
|
)
|
|
149
243
|
except InvalidRunStatusException as e:
|
|
150
244
|
abort_grpc_context(e.message, context)
|
|
@@ -81,12 +81,9 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
|
81
81
|
metadata sent by the node. Continue RPC call if node is authenticated, else,
|
|
82
82
|
terminate RPC call by setting context to abort.
|
|
83
83
|
"""
|
|
84
|
-
#
|
|
84
|
+
# Only apply to Fleet service
|
|
85
85
|
if not handler_call_details.method.startswith("/flwr.proto.Fleet/"):
|
|
86
|
-
return
|
|
87
|
-
"This request should be sent to a different service.",
|
|
88
|
-
grpc.StatusCode.FAILED_PRECONDITION,
|
|
89
|
-
)
|
|
86
|
+
return continuation(handler_call_details)
|
|
90
87
|
|
|
91
88
|
state = self.state_factory.state()
|
|
92
89
|
metadata_dict = dict(handler_call_details.invocation_metadata)
|