flwr 1.18.0__py3-none-any.whl → 1.19.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/app/__init__.py +15 -0
- flwr/app/error.py +68 -0
- flwr/app/metadata.py +223 -0
- flwr/cli/build.py +82 -57
- flwr/cli/log.py +3 -3
- flwr/cli/login/login.py +3 -7
- flwr/cli/ls.py +15 -36
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +1 -1
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +2 -3
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +14 -17
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/run.py +10 -18
- flwr/cli/stop.py +2 -2
- flwr/cli/utils.py +31 -5
- flwr/client/__init__.py +2 -2
- flwr/client/client_app.py +1 -1
- flwr/client/clientapp/__init__.py +0 -7
- flwr/client/grpc_adapter_client/connection.py +4 -4
- flwr/client/grpc_rere_client/connection.py +130 -60
- flwr/client/grpc_rere_client/grpc_adapter.py +34 -6
- flwr/client/message_handler/message_handler.py +1 -1
- flwr/client/mod/comms_mods.py +36 -17
- flwr/client/rest_client/connection.py +173 -67
- flwr/clientapp/__init__.py +15 -0
- flwr/common/__init__.py +2 -2
- flwr/common/auth_plugin/__init__.py +2 -0
- flwr/common/auth_plugin/auth_plugin.py +29 -3
- flwr/common/constant.py +36 -7
- flwr/common/event_log_plugin/event_log_plugin.py +3 -3
- flwr/common/exit_handlers.py +30 -0
- flwr/common/heartbeat.py +165 -0
- flwr/common/inflatable.py +290 -0
- flwr/common/inflatable_grpc_utils.py +99 -0
- flwr/common/inflatable_rest_utils.py +99 -0
- flwr/common/inflatable_utils.py +341 -0
- flwr/common/message.py +110 -242
- flwr/common/record/__init__.py +2 -1
- flwr/common/record/array.py +323 -0
- flwr/common/record/arrayrecord.py +103 -225
- flwr/common/record/configrecord.py +59 -4
- flwr/common/record/conversion_utils.py +1 -1
- flwr/common/record/metricrecord.py +55 -4
- flwr/common/record/recorddict.py +69 -1
- flwr/common/recorddict_compat.py +2 -2
- flwr/common/retry_invoker.py +5 -1
- flwr/common/serde.py +59 -183
- flwr/common/serde_utils.py +175 -0
- flwr/common/typing.py +5 -3
- flwr/compat/__init__.py +15 -0
- flwr/compat/client/__init__.py +15 -0
- flwr/{client → compat/client}/app.py +19 -159
- flwr/compat/common/__init__.py +15 -0
- flwr/compat/server/__init__.py +15 -0
- flwr/compat/server/app.py +174 -0
- flwr/compat/simulation/__init__.py +15 -0
- flwr/proto/fleet_pb2.py +32 -27
- flwr/proto/fleet_pb2.pyi +49 -35
- flwr/proto/fleet_pb2_grpc.py +117 -13
- flwr/proto/fleet_pb2_grpc.pyi +47 -6
- flwr/proto/heartbeat_pb2.py +33 -0
- flwr/proto/heartbeat_pb2.pyi +66 -0
- flwr/proto/heartbeat_pb2_grpc.py +4 -0
- flwr/proto/heartbeat_pb2_grpc.pyi +4 -0
- flwr/proto/message_pb2.py +28 -11
- flwr/proto/message_pb2.pyi +125 -0
- flwr/proto/recorddict_pb2.py +16 -28
- flwr/proto/recorddict_pb2.pyi +46 -64
- flwr/proto/run_pb2.py +24 -32
- flwr/proto/run_pb2.pyi +4 -52
- flwr/proto/serverappio_pb2.py +32 -23
- flwr/proto/serverappio_pb2.pyi +45 -3
- flwr/proto/serverappio_pb2_grpc.py +138 -34
- flwr/proto/serverappio_pb2_grpc.pyi +54 -13
- flwr/proto/simulationio_pb2.py +12 -11
- flwr/proto/simulationio_pb2_grpc.py +35 -0
- flwr/proto/simulationio_pb2_grpc.pyi +14 -0
- flwr/server/__init__.py +1 -1
- flwr/server/app.py +68 -186
- flwr/server/compat/app_utils.py +50 -28
- flwr/server/fleet_event_log_interceptor.py +2 -2
- flwr/server/grid/grpc_grid.py +104 -34
- flwr/server/grid/inmemory_grid.py +5 -4
- flwr/server/serverapp/app.py +18 -0
- flwr/server/superlink/ffs/__init__.py +2 -0
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +13 -3
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +101 -7
- flwr/server/superlink/fleet/message_handler/message_handler.py +135 -18
- flwr/server/superlink/fleet/rest_rere/rest_api.py +72 -11
- flwr/server/superlink/fleet/vce/vce_api.py +6 -3
- flwr/server/superlink/linkstate/in_memory_linkstate.py +138 -43
- flwr/server/superlink/linkstate/linkstate.py +53 -20
- flwr/server/superlink/linkstate/sqlite_linkstate.py +149 -55
- flwr/server/superlink/linkstate/utils.py +33 -29
- flwr/server/superlink/serverappio/serverappio_grpc.py +3 -0
- flwr/server/superlink/serverappio/serverappio_servicer.py +211 -57
- flwr/server/superlink/simulation/simulationio_servicer.py +25 -1
- flwr/server/superlink/utils.py +44 -2
- flwr/server/utils/validator.py +2 -2
- flwr/serverapp/__init__.py +15 -0
- flwr/simulation/app.py +17 -0
- flwr/supercore/__init__.py +15 -0
- flwr/supercore/object_store/__init__.py +24 -0
- flwr/supercore/object_store/in_memory_object_store.py +229 -0
- flwr/supercore/object_store/object_store.py +192 -0
- flwr/supercore/object_store/object_store_factory.py +44 -0
- flwr/superexec/deployment.py +6 -2
- flwr/superexec/exec_event_log_interceptor.py +4 -4
- flwr/superexec/exec_grpc.py +7 -3
- flwr/superexec/exec_servicer.py +125 -23
- flwr/superexec/exec_user_auth_interceptor.py +37 -8
- flwr/superexec/executor.py +4 -0
- flwr/superexec/simulation.py +7 -1
- flwr/superlink/__init__.py +15 -0
- flwr/{client/supernode → supernode}/__init__.py +0 -7
- flwr/{client/nodestate/nodestate.py → supernode/cli/__init__.py} +7 -14
- flwr/{client/supernode/app.py → supernode/cli/flower_supernode.py} +3 -12
- flwr/supernode/cli/flwr_clientapp.py +81 -0
- flwr/supernode/nodestate/in_memory_nodestate.py +190 -0
- flwr/supernode/nodestate/nodestate.py +212 -0
- flwr/supernode/runtime/__init__.py +15 -0
- flwr/{client/clientapp/app.py → supernode/runtime/run_clientapp.py} +25 -56
- flwr/supernode/servicer/__init__.py +15 -0
- flwr/supernode/servicer/clientappio/__init__.py +24 -0
- flwr/supernode/start_client_internal.py +491 -0
- {flwr-1.18.0.dist-info → flwr-1.19.0.dist-info}/METADATA +5 -4
- {flwr-1.18.0.dist-info → flwr-1.19.0.dist-info}/RECORD +141 -108
- {flwr-1.18.0.dist-info → flwr-1.19.0.dist-info}/WHEEL +1 -1
- {flwr-1.18.0.dist-info → flwr-1.19.0.dist-info}/entry_points.txt +2 -2
- flwr/client/heartbeat.py +0 -74
- flwr/client/nodestate/in_memory_nodestate.py +0 -38
- /flwr/{client → compat/client}/grpc_client/__init__.py +0 -0
- /flwr/{client → compat/client}/grpc_client/connection.py +0 -0
- /flwr/{client → supernode}/nodestate/__init__.py +0 -0
- /flwr/{client → supernode}/nodestate/nodestate_factory.py +0 -0
- /flwr/{client/clientapp → supernode/servicer/clientappio}/clientappio_servicer.py +0 -0
|
@@ -15,47 +15,61 @@
|
|
|
15
15
|
"""Contextmanager for a gRPC request-response channel to the Flower server."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
import random
|
|
19
|
-
import threading
|
|
20
18
|
from collections.abc import Iterator, Sequence
|
|
21
19
|
from contextlib import contextmanager
|
|
22
20
|
from copy import copy
|
|
23
|
-
from logging import ERROR
|
|
21
|
+
from logging import DEBUG, ERROR
|
|
24
22
|
from pathlib import Path
|
|
25
23
|
from typing import Callable, Optional, Union, cast
|
|
26
24
|
|
|
27
25
|
import grpc
|
|
28
26
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
29
27
|
|
|
30
|
-
from flwr.
|
|
28
|
+
from flwr.app.metadata import Metadata
|
|
31
29
|
from flwr.client.message_handler.message_handler import validate_out_message
|
|
32
30
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
|
33
|
-
from flwr.common.constant import
|
|
34
|
-
PING_BASE_MULTIPLIER,
|
|
35
|
-
PING_CALL_TIMEOUT,
|
|
36
|
-
PING_DEFAULT_INTERVAL,
|
|
37
|
-
PING_RANDOM_RANGE,
|
|
38
|
-
)
|
|
31
|
+
from flwr.common.constant import HEARTBEAT_CALL_TIMEOUT, HEARTBEAT_DEFAULT_INTERVAL
|
|
39
32
|
from flwr.common.grpc import create_channel, on_channel_state_change
|
|
33
|
+
from flwr.common.heartbeat import HeartbeatSender
|
|
34
|
+
from flwr.common.inflatable import (
|
|
35
|
+
get_all_nested_objects,
|
|
36
|
+
get_object_tree,
|
|
37
|
+
no_object_id_recompute,
|
|
38
|
+
)
|
|
39
|
+
from flwr.common.inflatable_grpc_utils import (
|
|
40
|
+
make_pull_object_fn_grpc,
|
|
41
|
+
make_push_object_fn_grpc,
|
|
42
|
+
)
|
|
43
|
+
from flwr.common.inflatable_utils import (
|
|
44
|
+
inflate_object_from_contents,
|
|
45
|
+
pull_objects,
|
|
46
|
+
push_objects,
|
|
47
|
+
)
|
|
40
48
|
from flwr.common.logger import log
|
|
41
|
-
from flwr.common.message import Message,
|
|
42
|
-
from flwr.common.retry_invoker import RetryInvoker
|
|
49
|
+
from flwr.common.message import Message, remove_content_from_message
|
|
50
|
+
from flwr.common.retry_invoker import RetryInvoker, _wrap_stub
|
|
43
51
|
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
|
44
52
|
generate_key_pairs,
|
|
45
53
|
)
|
|
46
|
-
from flwr.common.serde import
|
|
54
|
+
from flwr.common.serde import message_to_proto, run_from_proto
|
|
47
55
|
from flwr.common.typing import Fab, Run, RunNotRunningException
|
|
48
56
|
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
49
57
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
50
58
|
CreateNodeRequest,
|
|
51
59
|
DeleteNodeRequest,
|
|
52
|
-
PingRequest,
|
|
53
|
-
PingResponse,
|
|
54
60
|
PullMessagesRequest,
|
|
55
61
|
PullMessagesResponse,
|
|
56
62
|
PushMessagesRequest,
|
|
63
|
+
PushMessagesResponse,
|
|
57
64
|
)
|
|
58
65
|
from flwr.proto.fleet_pb2_grpc import FleetStub # pylint: disable=E0611
|
|
66
|
+
from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
|
|
67
|
+
SendNodeHeartbeatRequest,
|
|
68
|
+
SendNodeHeartbeatResponse,
|
|
69
|
+
)
|
|
70
|
+
from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
|
71
|
+
ConfirmMessageReceivedRequest,
|
|
72
|
+
)
|
|
59
73
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
60
74
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
61
75
|
|
|
@@ -78,10 +92,10 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
78
92
|
tuple[
|
|
79
93
|
Callable[[], Optional[Message]],
|
|
80
94
|
Callable[[Message], None],
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
95
|
+
Callable[[], Optional[int]],
|
|
96
|
+
Callable[[], None],
|
|
97
|
+
Callable[[int], Run],
|
|
98
|
+
Callable[[str, int], Fab],
|
|
85
99
|
]
|
|
86
100
|
]:
|
|
87
101
|
"""Primitives for request/response-based interaction with a server.
|
|
@@ -151,8 +165,6 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
151
165
|
stub = adapter_cls(channel)
|
|
152
166
|
metadata: Optional[Metadata] = None
|
|
153
167
|
node: Optional[Node] = None
|
|
154
|
-
ping_thread: Optional[threading.Thread] = None
|
|
155
|
-
ping_stop_event = threading.Event()
|
|
156
168
|
|
|
157
169
|
def _should_giveup_fn(e: Exception) -> bool:
|
|
158
170
|
if e.code() == grpc.StatusCode.PERMISSION_DENIED: # type: ignore
|
|
@@ -165,46 +177,58 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
165
177
|
# If the status code is PERMISSION_DENIED, additionally raise RunNotRunningException
|
|
166
178
|
retry_invoker.should_giveup = _should_giveup_fn
|
|
167
179
|
|
|
180
|
+
# Wrap stub
|
|
181
|
+
_wrap_stub(stub, retry_invoker)
|
|
168
182
|
###########################################################################
|
|
169
|
-
#
|
|
183
|
+
# send_node_heartbeat/create_node/delete_node/receive/send/get_run functions
|
|
170
184
|
###########################################################################
|
|
171
185
|
|
|
172
|
-
def
|
|
186
|
+
def send_node_heartbeat() -> bool:
|
|
173
187
|
# Get Node
|
|
174
188
|
if node is None:
|
|
175
189
|
log(ERROR, "Node instance missing")
|
|
176
|
-
return
|
|
190
|
+
return False
|
|
177
191
|
|
|
178
|
-
# Construct the
|
|
179
|
-
req =
|
|
192
|
+
# Construct the heartbeat request
|
|
193
|
+
req = SendNodeHeartbeatRequest(
|
|
194
|
+
node=node, heartbeat_interval=HEARTBEAT_DEFAULT_INTERVAL
|
|
195
|
+
)
|
|
180
196
|
|
|
181
197
|
# Call FleetAPI
|
|
182
|
-
|
|
198
|
+
try:
|
|
199
|
+
res: SendNodeHeartbeatResponse = stub.SendNodeHeartbeat(
|
|
200
|
+
req, timeout=HEARTBEAT_CALL_TIMEOUT
|
|
201
|
+
)
|
|
202
|
+
except grpc.RpcError as e:
|
|
203
|
+
status_code = e.code()
|
|
204
|
+
if status_code == grpc.StatusCode.UNAVAILABLE:
|
|
205
|
+
return False
|
|
206
|
+
if status_code == grpc.StatusCode.DEADLINE_EXCEEDED:
|
|
207
|
+
return False
|
|
208
|
+
raise
|
|
183
209
|
|
|
184
210
|
# Check if success
|
|
185
211
|
if not res.success:
|
|
186
|
-
raise RuntimeError(
|
|
212
|
+
raise RuntimeError(
|
|
213
|
+
"Heartbeat failed unexpectedly. The SuperLink does not "
|
|
214
|
+
"recognize this SuperNode."
|
|
215
|
+
)
|
|
216
|
+
return True
|
|
187
217
|
|
|
188
|
-
|
|
189
|
-
rd = random.uniform(*PING_RANDOM_RANGE)
|
|
190
|
-
next_interval: float = PING_DEFAULT_INTERVAL - PING_CALL_TIMEOUT
|
|
191
|
-
next_interval *= PING_BASE_MULTIPLIER + rd
|
|
192
|
-
if not ping_stop_event.is_set():
|
|
193
|
-
ping_stop_event.wait(next_interval)
|
|
218
|
+
heartbeat_sender = HeartbeatSender(send_node_heartbeat)
|
|
194
219
|
|
|
195
220
|
def create_node() -> Optional[int]:
|
|
196
221
|
"""Set create_node."""
|
|
197
222
|
# Call FleetAPI
|
|
198
|
-
create_node_request = CreateNodeRequest(
|
|
199
|
-
|
|
200
|
-
stub.CreateNode,
|
|
201
|
-
request=create_node_request,
|
|
223
|
+
create_node_request = CreateNodeRequest(
|
|
224
|
+
heartbeat_interval=HEARTBEAT_DEFAULT_INTERVAL
|
|
202
225
|
)
|
|
226
|
+
create_node_response = stub.CreateNode(request=create_node_request)
|
|
203
227
|
|
|
204
|
-
# Remember the node and the
|
|
205
|
-
nonlocal node
|
|
228
|
+
# Remember the node and start the heartbeat sender
|
|
229
|
+
nonlocal node
|
|
206
230
|
node = cast(Node, create_node_response.node)
|
|
207
|
-
|
|
231
|
+
heartbeat_sender.start()
|
|
208
232
|
return node.node_id
|
|
209
233
|
|
|
210
234
|
def delete_node() -> None:
|
|
@@ -215,12 +239,12 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
215
239
|
log(ERROR, "Node instance missing")
|
|
216
240
|
return
|
|
217
241
|
|
|
218
|
-
# Stop the
|
|
219
|
-
|
|
242
|
+
# Stop the heartbeat sender
|
|
243
|
+
heartbeat_sender.stop()
|
|
220
244
|
|
|
221
245
|
# Call FleetAPI
|
|
222
246
|
delete_node_request = DeleteNodeRequest(node=node)
|
|
223
|
-
|
|
247
|
+
stub.DeleteNode(request=delete_node_request)
|
|
224
248
|
|
|
225
249
|
# Cleanup
|
|
226
250
|
node = None
|
|
@@ -234,9 +258,7 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
234
258
|
|
|
235
259
|
# Request instructions (message) from server
|
|
236
260
|
request = PullMessagesRequest(node=node)
|
|
237
|
-
response: PullMessagesResponse =
|
|
238
|
-
stub.PullMessages, request=request
|
|
239
|
-
)
|
|
261
|
+
response: PullMessagesResponse = stub.PullMessages(request=request)
|
|
240
262
|
|
|
241
263
|
# Get the current Messages
|
|
242
264
|
message_proto = (
|
|
@@ -250,7 +272,33 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
250
272
|
message_proto = None
|
|
251
273
|
|
|
252
274
|
# Construct the Message
|
|
253
|
-
in_message =
|
|
275
|
+
in_message: Optional[Message] = None
|
|
276
|
+
|
|
277
|
+
if message_proto:
|
|
278
|
+
msg_id = message_proto.metadata.message_id
|
|
279
|
+
run_id = message_proto.metadata.run_id
|
|
280
|
+
all_object_contents = pull_objects(
|
|
281
|
+
list(response.objects_to_pull[msg_id].object_ids) + [msg_id],
|
|
282
|
+
pull_object_fn=make_pull_object_fn_grpc(
|
|
283
|
+
pull_object_grpc=stub.PullObject,
|
|
284
|
+
node=node,
|
|
285
|
+
run_id=run_id,
|
|
286
|
+
),
|
|
287
|
+
)
|
|
288
|
+
|
|
289
|
+
# Confirm that the message has been received
|
|
290
|
+
stub.ConfirmMessageReceived(
|
|
291
|
+
ConfirmMessageReceivedRequest(
|
|
292
|
+
node=node, run_id=run_id, message_object_id=msg_id
|
|
293
|
+
)
|
|
294
|
+
)
|
|
295
|
+
|
|
296
|
+
in_message = cast(
|
|
297
|
+
Message, inflate_object_from_contents(msg_id, all_object_contents)
|
|
298
|
+
)
|
|
299
|
+
# The deflated message doesn't contain the message_id (its own object_id)
|
|
300
|
+
# Inject
|
|
301
|
+
in_message.metadata.__dict__["_message_id"] = msg_id
|
|
254
302
|
|
|
255
303
|
# Remember `metadata` of the in message
|
|
256
304
|
nonlocal metadata
|
|
@@ -272,15 +320,43 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
272
320
|
log(ERROR, "No current message")
|
|
273
321
|
return
|
|
274
322
|
|
|
323
|
+
# Set message_id
|
|
324
|
+
message.metadata.__dict__["_message_id"] = message.object_id
|
|
275
325
|
# Validate out message
|
|
276
326
|
if not validate_out_message(message, metadata):
|
|
277
327
|
log(ERROR, "Invalid out message")
|
|
278
328
|
return
|
|
279
329
|
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
330
|
+
with no_object_id_recompute():
|
|
331
|
+
# Get all nested objects
|
|
332
|
+
all_objects = get_all_nested_objects(message)
|
|
333
|
+
object_tree = get_object_tree(message)
|
|
334
|
+
|
|
335
|
+
# Serialize Message
|
|
336
|
+
message_proto = message_to_proto(
|
|
337
|
+
message=remove_content_from_message(message)
|
|
338
|
+
)
|
|
339
|
+
request = PushMessagesRequest(
|
|
340
|
+
node=node,
|
|
341
|
+
messages_list=[message_proto],
|
|
342
|
+
message_object_trees=[object_tree],
|
|
343
|
+
)
|
|
344
|
+
response: PushMessagesResponse = stub.PushMessages(request=request)
|
|
345
|
+
|
|
346
|
+
if response.objects_to_push:
|
|
347
|
+
objs_to_push = set(
|
|
348
|
+
response.objects_to_push[message.object_id].object_ids
|
|
349
|
+
)
|
|
350
|
+
push_objects(
|
|
351
|
+
all_objects,
|
|
352
|
+
push_object_fn=make_push_object_fn_grpc(
|
|
353
|
+
push_object_grpc=stub.PushObject,
|
|
354
|
+
node=node,
|
|
355
|
+
run_id=message.metadata.run_id,
|
|
356
|
+
),
|
|
357
|
+
object_ids_to_push=objs_to_push,
|
|
358
|
+
)
|
|
359
|
+
log(DEBUG, "Pushed %s objects to servicer.", len(objs_to_push))
|
|
284
360
|
|
|
285
361
|
# Cleanup
|
|
286
362
|
metadata = None
|
|
@@ -288,10 +364,7 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
288
364
|
def get_run(run_id: int) -> Run:
|
|
289
365
|
# Call FleetAPI
|
|
290
366
|
get_run_request = GetRunRequest(node=node, run_id=run_id)
|
|
291
|
-
get_run_response: GetRunResponse =
|
|
292
|
-
stub.GetRun,
|
|
293
|
-
request=get_run_request,
|
|
294
|
-
)
|
|
367
|
+
get_run_response: GetRunResponse = stub.GetRun(request=get_run_request)
|
|
295
368
|
|
|
296
369
|
# Return fab_id and fab_version
|
|
297
370
|
return run_from_proto(get_run_response.run)
|
|
@@ -299,10 +372,7 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
299
372
|
def get_fab(fab_hash: str, run_id: int) -> Fab:
|
|
300
373
|
# Call FleetAPI
|
|
301
374
|
get_fab_request = GetFabRequest(node=node, hash_str=fab_hash, run_id=run_id)
|
|
302
|
-
get_fab_response: GetFabResponse =
|
|
303
|
-
stub.GetFab,
|
|
304
|
-
request=get_fab_request,
|
|
305
|
-
)
|
|
375
|
+
get_fab_response: GetFabResponse = stub.GetFab(request=get_fab_request)
|
|
306
376
|
|
|
307
377
|
return Fab(get_fab_response.fab.hash_str, get_fab_response.fab.content)
|
|
308
378
|
|
|
@@ -38,8 +38,6 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
38
38
|
CreateNodeResponse,
|
|
39
39
|
DeleteNodeRequest,
|
|
40
40
|
DeleteNodeResponse,
|
|
41
|
-
PingRequest,
|
|
42
|
-
PingResponse,
|
|
43
41
|
PullMessagesRequest,
|
|
44
42
|
PullMessagesResponse,
|
|
45
43
|
PushMessagesRequest,
|
|
@@ -47,6 +45,18 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
47
45
|
)
|
|
48
46
|
from flwr.proto.grpcadapter_pb2 import MessageContainer # pylint: disable=E0611
|
|
49
47
|
from flwr.proto.grpcadapter_pb2_grpc import GrpcAdapterStub
|
|
48
|
+
from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
|
|
49
|
+
SendNodeHeartbeatRequest,
|
|
50
|
+
SendNodeHeartbeatResponse,
|
|
51
|
+
)
|
|
52
|
+
from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
|
53
|
+
ConfirmMessageReceivedRequest,
|
|
54
|
+
ConfirmMessageReceivedResponse,
|
|
55
|
+
PullObjectRequest,
|
|
56
|
+
PullObjectResponse,
|
|
57
|
+
PushObjectRequest,
|
|
58
|
+
PushObjectResponse,
|
|
59
|
+
)
|
|
50
60
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
51
61
|
|
|
52
62
|
T = TypeVar("T", bound=GrpcMessage)
|
|
@@ -120,11 +130,11 @@ class GrpcAdapter:
|
|
|
120
130
|
"""."""
|
|
121
131
|
return self._send_and_receive(request, DeleteNodeResponse, **kwargs)
|
|
122
132
|
|
|
123
|
-
def
|
|
124
|
-
self, request:
|
|
125
|
-
) ->
|
|
133
|
+
def SendNodeHeartbeat( # pylint: disable=C0103
|
|
134
|
+
self, request: SendNodeHeartbeatRequest, **kwargs: Any
|
|
135
|
+
) -> SendNodeHeartbeatResponse:
|
|
126
136
|
"""."""
|
|
127
|
-
return self._send_and_receive(request,
|
|
137
|
+
return self._send_and_receive(request, SendNodeHeartbeatResponse, **kwargs)
|
|
128
138
|
|
|
129
139
|
def PullMessages( # pylint: disable=C0103
|
|
130
140
|
self, request: PullMessagesRequest, **kwargs: Any
|
|
@@ -149,3 +159,21 @@ class GrpcAdapter:
|
|
|
149
159
|
) -> GetFabResponse:
|
|
150
160
|
"""."""
|
|
151
161
|
return self._send_and_receive(request, GetFabResponse, **kwargs)
|
|
162
|
+
|
|
163
|
+
def PushObject( # pylint: disable=C0103
|
|
164
|
+
self, request: PushObjectRequest, **kwargs: Any
|
|
165
|
+
) -> PushObjectResponse:
|
|
166
|
+
"""."""
|
|
167
|
+
return self._send_and_receive(request, PushObjectResponse, **kwargs)
|
|
168
|
+
|
|
169
|
+
def PullObject( # pylint: disable=C0103
|
|
170
|
+
self, request: PullObjectRequest, **kwargs: Any
|
|
171
|
+
) -> PullObjectResponse:
|
|
172
|
+
"""."""
|
|
173
|
+
return self._send_and_receive(request, PullObjectResponse, **kwargs)
|
|
174
|
+
|
|
175
|
+
def ConfirmMessageReceived( # pylint: disable=C0103
|
|
176
|
+
self, request: ConfirmMessageReceivedRequest, **kwargs: Any
|
|
177
|
+
) -> ConfirmMessageReceivedResponse:
|
|
178
|
+
"""."""
|
|
179
|
+
return self._send_and_receive(request, ConfirmMessageReceivedResponse, **kwargs)
|
|
@@ -164,7 +164,7 @@ def validate_out_message(out_message: Message, in_message_metadata: Metadata) ->
|
|
|
164
164
|
in_meta = in_message_metadata
|
|
165
165
|
if ( # pylint: disable-next=too-many-boolean-expressions
|
|
166
166
|
out_meta.run_id == in_meta.run_id
|
|
167
|
-
and out_meta.message_id ==
|
|
167
|
+
and out_meta.message_id == out_message.object_id # Should match the object id
|
|
168
168
|
and out_meta.src_node_id == in_meta.dst_node_id
|
|
169
169
|
and out_meta.dst_node_id == in_meta.src_node_id
|
|
170
170
|
and out_meta.reply_to_message_id == in_meta.message_id
|
flwr/client/mod/comms_mods.py
CHANGED
|
@@ -32,14 +32,17 @@ def message_size_mod(
|
|
|
32
32
|
|
|
33
33
|
This mod logs the size in bytes of the message being transmited.
|
|
34
34
|
"""
|
|
35
|
-
|
|
35
|
+
# Log the size of the incoming message in bytes
|
|
36
|
+
total_bytes = sum(record.count_bytes() for record in msg.content.values())
|
|
37
|
+
log(INFO, "Incoming message size: %i bytes", total_bytes)
|
|
36
38
|
|
|
37
|
-
|
|
38
|
-
|
|
39
|
+
# Call the next layer
|
|
40
|
+
msg = call_next(msg, ctxt)
|
|
39
41
|
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
42
|
+
# Log the size of the outgoing message in bytes
|
|
43
|
+
total_bytes = sum(record.count_bytes() for record in msg.content.values())
|
|
44
|
+
log(INFO, "Outgoing message size: %i bytes", total_bytes)
|
|
45
|
+
return msg
|
|
43
46
|
|
|
44
47
|
|
|
45
48
|
def arrays_size_mod(
|
|
@@ -50,25 +53,41 @@ def arrays_size_mod(
|
|
|
50
53
|
This mod logs the number of array elements transmitted in ``ArrayRecord`` objects
|
|
51
54
|
of the message as well as their sizes in bytes.
|
|
52
55
|
"""
|
|
53
|
-
|
|
54
|
-
|
|
56
|
+
# Log the ArrayRecord size statistics and the total size in the incoming message
|
|
57
|
+
array_record_size_stats = _get_array_record_size_stats(msg)
|
|
58
|
+
total_bytes = sum(stat["bytes"] for stat in array_record_size_stats.values())
|
|
59
|
+
if array_record_size_stats:
|
|
60
|
+
log(INFO, "Incoming `ArrayRecord` size statistics:")
|
|
61
|
+
log(INFO, array_record_size_stats)
|
|
62
|
+
log(INFO, "Total array elements received: %i bytes", total_bytes)
|
|
63
|
+
|
|
64
|
+
msg = call_next(msg, ctxt)
|
|
65
|
+
|
|
66
|
+
# Log the ArrayRecord size statistics and the total size in the outgoing message
|
|
67
|
+
array_record_size_stats = _get_array_record_size_stats(msg)
|
|
68
|
+
total_bytes = sum(stat["bytes"] for stat in array_record_size_stats.values())
|
|
69
|
+
if array_record_size_stats:
|
|
70
|
+
log(INFO, "Outgoing `ArrayRecord` size statistics:")
|
|
71
|
+
log(INFO, array_record_size_stats)
|
|
72
|
+
log(INFO, "Total array elements sent: %i bytes", total_bytes)
|
|
73
|
+
return msg
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def _get_array_record_size_stats(
|
|
77
|
+
msg: Message,
|
|
78
|
+
) -> dict[str, dict[str, int]]:
|
|
79
|
+
"""Get `ArrayRecord` size statistics from the message."""
|
|
80
|
+
array_record_size_stats = {}
|
|
55
81
|
for record_name, arr_record in msg.content.array_records.items():
|
|
56
82
|
arr_record_bytes = arr_record.count_bytes()
|
|
57
|
-
arrays_size_in_bytes += arr_record_bytes
|
|
58
83
|
element_count = 0
|
|
59
84
|
for array in arr_record.values():
|
|
60
85
|
element_count += (
|
|
61
86
|
int(np.prod(array.shape)) if array.shape else array.numpy().size
|
|
62
87
|
)
|
|
63
88
|
|
|
64
|
-
|
|
89
|
+
array_record_size_stats[record_name] = {
|
|
65
90
|
"elements": element_count,
|
|
66
91
|
"bytes": arr_record_bytes,
|
|
67
92
|
}
|
|
68
|
-
|
|
69
|
-
if model_size_stats:
|
|
70
|
-
log(INFO, model_size_stats)
|
|
71
|
-
|
|
72
|
-
log(INFO, "Total array elements transmitted: %i bytes", arrays_size_in_bytes)
|
|
73
|
-
|
|
74
|
-
return call_next(msg, ctxt)
|
|
93
|
+
return array_record_size_stats
|