flwr 1.16.0__py3-none-any.whl → 1.18.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/__init__.py +1 -1
- flwr/cli/__init__.py +1 -1
- flwr/cli/app.py +21 -2
- flwr/cli/build.py +1 -1
- flwr/cli/cli_user_auth_interceptor.py +1 -1
- flwr/cli/config_utils.py +53 -17
- flwr/cli/example.py +1 -1
- flwr/cli/install.py +1 -1
- flwr/cli/log.py +1 -1
- flwr/cli/login/__init__.py +1 -1
- flwr/cli/login/login.py +12 -1
- flwr/cli/ls.py +1 -1
- flwr/cli/new/__init__.py +1 -1
- flwr/cli/new/new.py +4 -4
- flwr/cli/new/templates/__init__.py +1 -1
- flwr/cli/new/templates/app/__init__.py +1 -1
- flwr/cli/new/templates/app/code/__init__.py +1 -1
- flwr/cli/new/templates/app/code/flwr_tune/__init__.py +1 -1
- flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +5 -5
- flwr/cli/new/templates/app/code/task.sklearn.py.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/__init__.py +1 -1
- flwr/cli/run/run.py +6 -10
- flwr/cli/stop.py +1 -1
- flwr/cli/utils.py +11 -12
- flwr/client/__init__.py +1 -1
- flwr/client/app.py +58 -56
- flwr/client/client.py +1 -1
- flwr/client/client_app.py +231 -166
- flwr/client/clientapp/__init__.py +1 -1
- flwr/client/clientapp/app.py +3 -3
- flwr/client/clientapp/clientappio_servicer.py +1 -1
- flwr/client/clientapp/utils.py +1 -1
- flwr/client/dpfedavg_numpy_client.py +1 -1
- flwr/client/grpc_adapter_client/__init__.py +1 -1
- flwr/client/grpc_adapter_client/connection.py +1 -1
- flwr/client/grpc_client/__init__.py +1 -1
- flwr/client/grpc_client/connection.py +37 -34
- flwr/client/grpc_rere_client/__init__.py +1 -1
- flwr/client/grpc_rere_client/client_interceptor.py +1 -1
- flwr/client/grpc_rere_client/connection.py +1 -1
- flwr/client/grpc_rere_client/grpc_adapter.py +1 -1
- flwr/client/heartbeat.py +1 -1
- flwr/client/message_handler/__init__.py +1 -1
- flwr/client/message_handler/message_handler.py +28 -28
- flwr/client/mod/__init__.py +3 -3
- flwr/client/mod/centraldp_mods.py +8 -8
- flwr/client/mod/comms_mods.py +17 -23
- flwr/client/mod/localdp_mod.py +10 -10
- flwr/client/mod/secure_aggregation/__init__.py +1 -1
- flwr/client/mod/secure_aggregation/secagg_mod.py +1 -1
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +32 -32
- flwr/client/mod/utils.py +1 -1
- flwr/client/nodestate/__init__.py +1 -1
- flwr/client/nodestate/in_memory_nodestate.py +1 -1
- flwr/client/nodestate/nodestate.py +1 -1
- flwr/client/nodestate/nodestate_factory.py +1 -1
- flwr/client/numpy_client.py +1 -1
- flwr/client/rest_client/__init__.py +1 -1
- flwr/client/rest_client/connection.py +1 -1
- flwr/client/run_info_store.py +3 -3
- flwr/client/supernode/__init__.py +1 -1
- flwr/client/supernode/app.py +1 -1
- flwr/client/typing.py +1 -1
- flwr/common/__init__.py +13 -5
- flwr/common/address.py +1 -1
- flwr/common/args.py +1 -1
- flwr/common/auth_plugin/__init__.py +1 -1
- flwr/common/auth_plugin/auth_plugin.py +1 -1
- flwr/common/config.py +5 -5
- flwr/common/constant.py +7 -7
- flwr/common/context.py +5 -5
- flwr/common/date.py +1 -1
- flwr/common/differential_privacy.py +1 -1
- flwr/common/differential_privacy_constants.py +1 -1
- flwr/common/dp.py +1 -1
- flwr/common/event_log_plugin/event_log_plugin.py +3 -3
- flwr/common/exit/exit.py +6 -6
- flwr/common/exit_handlers.py +1 -1
- flwr/common/grpc.py +1 -1
- flwr/common/logger.py +3 -3
- flwr/common/message.py +344 -102
- flwr/common/object_ref.py +1 -1
- flwr/common/parameter.py +1 -1
- flwr/common/pyproject.py +1 -1
- flwr/common/record/__init__.py +9 -5
- flwr/common/record/arrayrecord.py +626 -0
- flwr/common/record/{configsrecord.py → configrecord.py} +83 -37
- flwr/common/record/conversion_utils.py +2 -2
- flwr/common/record/{metricsrecord.py → metricrecord.py} +90 -44
- flwr/common/record/recorddict.py +337 -0
- flwr/common/record/typeddict.py +1 -1
- flwr/common/recorddict_compat.py +410 -0
- flwr/common/retry_invoker.py +10 -10
- flwr/common/secure_aggregation/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/shamir.py +52 -30
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -1
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
- flwr/common/secure_aggregation/quantization.py +1 -1
- flwr/common/secure_aggregation/secaggplus_constants.py +2 -2
- flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
- flwr/common/serde.py +67 -72
- flwr/common/telemetry.py +2 -2
- flwr/common/typing.py +9 -9
- flwr/common/version.py +1 -1
- flwr/proto/__init__.py +1 -1
- flwr/proto/exec_pb2.py +3 -3
- flwr/proto/exec_pb2.pyi +3 -3
- flwr/proto/message_pb2.py +12 -12
- flwr/proto/message_pb2.pyi +9 -9
- flwr/proto/recorddict_pb2.py +70 -0
- flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +35 -35
- flwr/proto/run_pb2.py +31 -31
- flwr/proto/run_pb2.pyi +3 -3
- flwr/server/__init__.py +4 -2
- flwr/server/app.py +67 -12
- flwr/server/client_manager.py +1 -1
- flwr/server/client_proxy.py +1 -1
- flwr/server/compat/__init__.py +3 -3
- flwr/server/compat/app.py +12 -12
- flwr/server/compat/app_utils.py +17 -17
- flwr/server/compat/{driver_client_proxy.py → grid_client_proxy.py} +39 -39
- flwr/server/compat/legacy_context.py +1 -1
- flwr/server/criterion.py +1 -1
- flwr/server/fleet_event_log_interceptor.py +94 -0
- flwr/server/{driver → grid}/__init__.py +8 -7
- flwr/server/{driver/driver.py → grid/grid.py} +48 -19
- flwr/server/{driver/grpc_driver.py → grid/grpc_grid.py} +87 -64
- flwr/server/{driver/inmemory_driver.py → grid/inmemory_grid.py} +24 -34
- flwr/server/history.py +1 -1
- flwr/server/run_serverapp.py +5 -5
- flwr/server/server.py +1 -1
- flwr/server/server_app.py +98 -71
- flwr/server/server_config.py +1 -1
- flwr/server/serverapp/__init__.py +1 -1
- flwr/server/serverapp/app.py +11 -11
- flwr/server/serverapp_components.py +1 -1
- flwr/server/strategy/__init__.py +1 -1
- flwr/server/strategy/aggregate.py +1 -1
- flwr/server/strategy/bulyan.py +2 -2
- flwr/server/strategy/dp_adaptive_clipping.py +17 -17
- flwr/server/strategy/dp_fixed_clipping.py +17 -17
- flwr/server/strategy/dpfedavg_adaptive.py +1 -1
- flwr/server/strategy/dpfedavg_fixed.py +1 -1
- flwr/server/strategy/fault_tolerant_fedavg.py +1 -1
- flwr/server/strategy/fedadagrad.py +1 -1
- flwr/server/strategy/fedadam.py +1 -1
- flwr/server/strategy/fedavg.py +1 -1
- flwr/server/strategy/fedavg_android.py +1 -1
- flwr/server/strategy/fedavgm.py +1 -1
- flwr/server/strategy/fedmedian.py +1 -1
- flwr/server/strategy/fedopt.py +1 -1
- flwr/server/strategy/fedprox.py +1 -1
- flwr/server/strategy/fedtrimmedavg.py +1 -1
- flwr/server/strategy/fedxgb_bagging.py +1 -1
- flwr/server/strategy/fedxgb_cyclic.py +1 -1
- flwr/server/strategy/fedxgb_nn_avg.py +3 -2
- flwr/server/strategy/fedyogi.py +1 -1
- flwr/server/strategy/krum.py +1 -1
- flwr/server/strategy/qfedavg.py +1 -1
- flwr/server/strategy/strategy.py +1 -1
- flwr/server/superlink/__init__.py +1 -1
- flwr/server/superlink/ffs/__init__.py +1 -1
- flwr/server/superlink/ffs/disk_ffs.py +1 -1
- flwr/server/superlink/ffs/ffs.py +1 -1
- flwr/server/superlink/ffs/ffs_factory.py +1 -1
- flwr/server/superlink/fleet/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_adapter/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +13 -13
- flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +1 -1
- flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
- flwr/server/superlink/fleet/message_handler/message_handler.py +1 -1
- flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/rest_rere/rest_api.py +1 -1
- flwr/server/superlink/fleet/vce/__init__.py +1 -1
- flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
- flwr/server/superlink/fleet/vce/backend/backend.py +3 -3
- flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -3
- flwr/server/superlink/fleet/vce/vce_api.py +2 -4
- flwr/server/superlink/linkstate/__init__.py +1 -1
- flwr/server/superlink/linkstate/in_memory_linkstate.py +34 -9
- flwr/server/superlink/linkstate/linkstate.py +5 -5
- flwr/server/superlink/linkstate/linkstate_factory.py +1 -1
- flwr/server/superlink/linkstate/sqlite_linkstate.py +62 -28
- flwr/server/superlink/linkstate/utils.py +94 -28
- flwr/server/superlink/{driver → serverappio}/__init__.py +1 -1
- flwr/server/superlink/{driver → serverappio}/serverappio_grpc.py +1 -1
- flwr/server/superlink/{driver → serverappio}/serverappio_servicer.py +4 -4
- flwr/server/superlink/simulation/__init__.py +1 -1
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
- flwr/server/superlink/simulation/simulationio_servicer.py +3 -3
- flwr/server/superlink/utils.py +1 -1
- flwr/server/typing.py +4 -4
- flwr/server/utils/__init__.py +1 -1
- flwr/server/utils/tensorboard.py +1 -1
- flwr/server/utils/validator.py +5 -5
- flwr/server/workflow/__init__.py +1 -1
- flwr/server/workflow/constant.py +1 -1
- flwr/server/workflow/default_workflows.py +49 -58
- flwr/server/workflow/secure_aggregation/__init__.py +1 -1
- flwr/server/workflow/secure_aggregation/secagg_workflow.py +1 -1
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +49 -51
- flwr/simulation/__init__.py +1 -1
- flwr/simulation/app.py +3 -3
- flwr/simulation/legacy_app.py +1 -1
- flwr/simulation/ray_transport/__init__.py +1 -1
- flwr/simulation/ray_transport/ray_actor.py +5 -3
- flwr/simulation/ray_transport/ray_client_proxy.py +35 -33
- flwr/simulation/ray_transport/utils.py +1 -1
- flwr/simulation/run_simulation.py +17 -17
- flwr/simulation/simulationio_connection.py +1 -1
- flwr/superexec/__init__.py +1 -1
- flwr/superexec/app.py +1 -1
- flwr/superexec/deployment.py +5 -5
- flwr/superexec/exec_event_log_interceptor.py +135 -0
- flwr/superexec/exec_grpc.py +11 -5
- flwr/superexec/exec_servicer.py +3 -3
- flwr/superexec/exec_user_auth_interceptor.py +19 -3
- flwr/superexec/executor.py +4 -4
- flwr/superexec/simulation.py +4 -4
- {flwr-1.16.0.dist-info → flwr-1.18.0.dist-info}/METADATA +3 -3
- flwr-1.18.0.dist-info/RECORD +332 -0
- flwr/common/record/parametersrecord.py +0 -339
- flwr/common/record/recordset.py +0 -209
- flwr/common/recordset_compat.py +0 -418
- flwr/proto/recordset_pb2.py +0 -70
- flwr-1.16.0.dist-info/LICENSE +0 -202
- flwr-1.16.0.dist-info/RECORD +0 -331
- /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
- /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
- {flwr-1.16.0.dist-info → flwr-1.18.0.dist-info}/WHEEL +0 -0
- {flwr-1.16.0.dist-info → flwr-1.18.0.dist-info}/entry_points.txt +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -19,21 +19,22 @@ from os import urandom
|
|
|
19
19
|
from typing import Optional
|
|
20
20
|
from uuid import UUID, uuid4
|
|
21
21
|
|
|
22
|
-
from flwr.common import
|
|
23
|
-
from flwr.common.constant import
|
|
22
|
+
from flwr.common import ConfigRecord, Context, Error, Message, Metadata, now, serde
|
|
23
|
+
from flwr.common.constant import (
|
|
24
|
+
SUPERLINK_NODE_ID,
|
|
25
|
+
ErrorCode,
|
|
26
|
+
MessageType,
|
|
27
|
+
Status,
|
|
28
|
+
SubStatus,
|
|
29
|
+
)
|
|
30
|
+
from flwr.common.message import make_message
|
|
24
31
|
from flwr.common.typing import RunStatus
|
|
25
32
|
|
|
26
33
|
# pylint: disable=E0611
|
|
27
34
|
from flwr.proto.message_pb2 import Context as ProtoContext
|
|
28
|
-
from flwr.proto.
|
|
35
|
+
from flwr.proto.recorddict_pb2 import ConfigRecord as ProtoConfigRecord
|
|
29
36
|
|
|
30
37
|
# pylint: enable=E0611
|
|
31
|
-
|
|
32
|
-
NODE_UNAVAILABLE_ERROR_REASON = (
|
|
33
|
-
"Error: Node Unavailable - The destination node is currently unavailable. "
|
|
34
|
-
"It exceeds the time limit specified in its last ping."
|
|
35
|
-
)
|
|
36
|
-
|
|
37
38
|
VALID_RUN_STATUS_TRANSITIONS = {
|
|
38
39
|
(Status.PENDING, Status.STARTING),
|
|
39
40
|
(Status.STARTING, Status.RUNNING),
|
|
@@ -54,6 +55,10 @@ MESSAGE_UNAVAILABLE_ERROR_REASON = (
|
|
|
54
55
|
REPLY_MESSAGE_UNAVAILABLE_ERROR_REASON = (
|
|
55
56
|
"Error: Reply Message Unavailable - The reply message has expired."
|
|
56
57
|
)
|
|
58
|
+
NODE_UNAVAILABLE_ERROR_REASON = (
|
|
59
|
+
"Error: Node Unavailable - The destination node is currently unavailable. "
|
|
60
|
+
"It exceeds twice the time limit specified in its last ping."
|
|
61
|
+
)
|
|
57
62
|
|
|
58
63
|
|
|
59
64
|
def generate_rand_int_from_bytes(
|
|
@@ -167,15 +172,15 @@ def context_from_bytes(context_bytes: bytes) -> Context:
|
|
|
167
172
|
return serde.context_from_proto(ProtoContext.FromString(context_bytes))
|
|
168
173
|
|
|
169
174
|
|
|
170
|
-
def
|
|
171
|
-
"""Serialize a `
|
|
172
|
-
return serde.
|
|
175
|
+
def configrecord_to_bytes(config_record: ConfigRecord) -> bytes:
|
|
176
|
+
"""Serialize a `ConfigRecord` to bytes."""
|
|
177
|
+
return serde.config_record_to_proto(config_record).SerializeToString()
|
|
173
178
|
|
|
174
179
|
|
|
175
|
-
def
|
|
176
|
-
"""Deserialize `
|
|
177
|
-
return serde.
|
|
178
|
-
|
|
180
|
+
def configrecord_from_bytes(configrecord_bytes: bytes) -> ConfigRecord:
|
|
181
|
+
"""Deserialize `ConfigRecord` from bytes."""
|
|
182
|
+
return serde.config_record_from_proto(
|
|
183
|
+
ProtoConfigRecord.FromString(configrecord_bytes)
|
|
179
184
|
)
|
|
180
185
|
|
|
181
186
|
|
|
@@ -231,7 +236,9 @@ def has_valid_sub_status(status: RunStatus) -> bool:
|
|
|
231
236
|
return status.sub_status == ""
|
|
232
237
|
|
|
233
238
|
|
|
234
|
-
def create_message_error_unavailable_res_message(
|
|
239
|
+
def create_message_error_unavailable_res_message(
|
|
240
|
+
ins_metadata: Metadata, error_type: str
|
|
241
|
+
) -> Message:
|
|
235
242
|
"""Generate an error Message that the SuperLink returns carrying the specified
|
|
236
243
|
error."""
|
|
237
244
|
current_time = now().timestamp()
|
|
@@ -241,22 +248,31 @@ def create_message_error_unavailable_res_message(ins_metadata: Metadata) -> Mess
|
|
|
241
248
|
message_id=str(uuid4()),
|
|
242
249
|
src_node_id=SUPERLINK_NODE_ID,
|
|
243
250
|
dst_node_id=SUPERLINK_NODE_ID,
|
|
244
|
-
|
|
251
|
+
reply_to_message_id=ins_metadata.message_id,
|
|
245
252
|
group_id=ins_metadata.group_id,
|
|
246
253
|
message_type=ins_metadata.message_type,
|
|
254
|
+
created_at=current_time,
|
|
247
255
|
ttl=ttl,
|
|
248
256
|
)
|
|
249
257
|
|
|
250
|
-
return
|
|
258
|
+
return make_message(
|
|
251
259
|
metadata=metadata,
|
|
252
260
|
error=Error(
|
|
253
|
-
code=
|
|
254
|
-
|
|
261
|
+
code=(
|
|
262
|
+
ErrorCode.REPLY_MESSAGE_UNAVAILABLE
|
|
263
|
+
if error_type == "msg_unavail"
|
|
264
|
+
else ErrorCode.NODE_UNAVAILABLE
|
|
265
|
+
),
|
|
266
|
+
reason=(
|
|
267
|
+
REPLY_MESSAGE_UNAVAILABLE_ERROR_REASON
|
|
268
|
+
if error_type == "msg_unavail"
|
|
269
|
+
else NODE_UNAVAILABLE_ERROR_REASON
|
|
270
|
+
),
|
|
255
271
|
),
|
|
256
272
|
)
|
|
257
273
|
|
|
258
274
|
|
|
259
|
-
def create_message_error_unavailable_ins_message(
|
|
275
|
+
def create_message_error_unavailable_ins_message(reply_to_message_id: UUID) -> Message:
|
|
260
276
|
"""Error to indicate that the enquired Message had expired before reply arrived or
|
|
261
277
|
that it isn't found."""
|
|
262
278
|
metadata = Metadata(
|
|
@@ -264,13 +280,14 @@ def create_message_error_unavailable_ins_message(reply_to_message: UUID) -> Mess
|
|
|
264
280
|
message_id=str(uuid4()),
|
|
265
281
|
src_node_id=SUPERLINK_NODE_ID,
|
|
266
282
|
dst_node_id=SUPERLINK_NODE_ID,
|
|
267
|
-
|
|
283
|
+
reply_to_message_id=str(reply_to_message_id),
|
|
268
284
|
group_id="", # Unknown
|
|
269
|
-
message_type=
|
|
285
|
+
message_type=MessageType.SYSTEM,
|
|
286
|
+
created_at=now().timestamp(),
|
|
270
287
|
ttl=0,
|
|
271
288
|
)
|
|
272
289
|
|
|
273
|
-
return
|
|
290
|
+
return make_message(
|
|
274
291
|
metadata=metadata,
|
|
275
292
|
error=Error(
|
|
276
293
|
code=ErrorCode.MESSAGE_UNAVAILABLE,
|
|
@@ -358,14 +375,63 @@ def verify_found_message_replies(
|
|
|
358
375
|
ret_dict: dict[UUID, Message] = {}
|
|
359
376
|
current = current_time if current_time else now().timestamp()
|
|
360
377
|
for message_res in found_message_res_list:
|
|
361
|
-
message_ins_id = UUID(message_res.metadata.
|
|
378
|
+
message_ins_id = UUID(message_res.metadata.reply_to_message_id)
|
|
362
379
|
if update_set:
|
|
363
380
|
inquired_message_ids.remove(message_ins_id)
|
|
364
381
|
# Check if the reply Message has expired
|
|
365
382
|
if message_ttl_has_expired(message_res.metadata, current):
|
|
366
383
|
# No need to insert the error Message
|
|
367
384
|
message_res = create_message_error_unavailable_res_message(
|
|
368
|
-
found_message_ins_dict[message_ins_id].metadata
|
|
385
|
+
found_message_ins_dict[message_ins_id].metadata, "msg_unavail"
|
|
369
386
|
)
|
|
370
387
|
ret_dict[message_ins_id] = message_res
|
|
371
388
|
return ret_dict
|
|
389
|
+
|
|
390
|
+
|
|
391
|
+
def check_node_availability_for_in_message(
|
|
392
|
+
inquired_in_message_ids: set[UUID],
|
|
393
|
+
found_in_message_dict: dict[UUID, Message],
|
|
394
|
+
node_id_to_online_until: dict[int, float],
|
|
395
|
+
current_time: Optional[float] = None,
|
|
396
|
+
update_set: bool = True,
|
|
397
|
+
) -> dict[UUID, Message]:
|
|
398
|
+
"""Check node availability for given Message and generate error reply Message if
|
|
399
|
+
unavailable. A Message error indicating node unavailability will be generated for
|
|
400
|
+
each given Message whose destination node is offline or non-existent.
|
|
401
|
+
|
|
402
|
+
Parameters
|
|
403
|
+
----------
|
|
404
|
+
inquired_in_message_ids : set[UUID]
|
|
405
|
+
Set of Message IDs for which to check destination node availability.
|
|
406
|
+
found_in_message_dict : dict[UUID, Message]
|
|
407
|
+
Dictionary containing all found Message indexed by their IDs.
|
|
408
|
+
node_id_to_online_until : dict[int, float]
|
|
409
|
+
Dictionary mapping node IDs to their online-until timestamps.
|
|
410
|
+
current_time : Optional[float] (default: None)
|
|
411
|
+
The current time to check for expiration. If set to `None`, the current time
|
|
412
|
+
will automatically be set to the current timestamp using `now().timestamp()`.
|
|
413
|
+
update_set : bool (default: True)
|
|
414
|
+
If True, the `inquired_in_message_ids` will be updated to remove invalid ones,
|
|
415
|
+
by default True.
|
|
416
|
+
|
|
417
|
+
Returns
|
|
418
|
+
-------
|
|
419
|
+
dict[UUID, Message]
|
|
420
|
+
A dictionary of error Message indexed by the corresponding Message ID.
|
|
421
|
+
"""
|
|
422
|
+
ret_dict = {}
|
|
423
|
+
current = current_time if current_time else now().timestamp()
|
|
424
|
+
for in_message_id in list(inquired_in_message_ids):
|
|
425
|
+
in_message = found_in_message_dict[in_message_id]
|
|
426
|
+
node_id = in_message.metadata.dst_node_id
|
|
427
|
+
online_until = node_id_to_online_until.get(node_id)
|
|
428
|
+
# Generate a reply message containing an error reply
|
|
429
|
+
# if the node is offline or doesn't exist.
|
|
430
|
+
if online_until is None or online_until < current:
|
|
431
|
+
if update_set:
|
|
432
|
+
inquired_in_message_ids.remove(in_message_id)
|
|
433
|
+
reply_message = create_message_error_unavailable_res_message(
|
|
434
|
+
in_message.metadata, "node_unavail"
|
|
435
|
+
)
|
|
436
|
+
ret_dict[in_message_id] = reply_message
|
|
437
|
+
return ret_dict
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -22,7 +22,7 @@ from uuid import UUID
|
|
|
22
22
|
|
|
23
23
|
import grpc
|
|
24
24
|
|
|
25
|
-
from flwr.common import
|
|
25
|
+
from flwr.common import ConfigRecord, Message
|
|
26
26
|
from flwr.common.constant import SUPERLINK_NODE_ID, Status
|
|
27
27
|
from flwr.common.logger import log
|
|
28
28
|
from flwr.common.serde import (
|
|
@@ -127,7 +127,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
127
127
|
request.fab_version,
|
|
128
128
|
fab_hash,
|
|
129
129
|
user_config_from_proto(request.override_config),
|
|
130
|
-
|
|
130
|
+
ConfigRecord(),
|
|
131
131
|
)
|
|
132
132
|
return CreateRunResponse(run_id=run_id)
|
|
133
133
|
|
|
@@ -206,7 +206,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
206
206
|
|
|
207
207
|
# Delete the instruction Messages and their replies if found
|
|
208
208
|
message_ins_ids_to_delete = {
|
|
209
|
-
UUID(msg_res.metadata.
|
|
209
|
+
UUID(msg_res.metadata.reply_to_message_id) for msg_res in messages_res
|
|
210
210
|
}
|
|
211
211
|
|
|
212
212
|
state.delete_messages(message_ins_ids=message_ins_ids_to_delete)
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -24,7 +24,7 @@ from grpc import ServicerContext
|
|
|
24
24
|
from flwr.common.constant import Status
|
|
25
25
|
from flwr.common.logger import log
|
|
26
26
|
from flwr.common.serde import (
|
|
27
|
-
|
|
27
|
+
config_record_to_proto,
|
|
28
28
|
context_from_proto,
|
|
29
29
|
context_to_proto,
|
|
30
30
|
fab_to_proto,
|
|
@@ -182,5 +182,5 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
|
|
|
182
182
|
)
|
|
183
183
|
return GetFederationOptionsResponse()
|
|
184
184
|
return GetFederationOptionsResponse(
|
|
185
|
-
federation_options=
|
|
185
|
+
federation_options=config_record_to_proto(federation_options)
|
|
186
186
|
)
|
flwr/server/superlink/utils.py
CHANGED
flwr/server/typing.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -19,9 +19,9 @@ from typing import Callable
|
|
|
19
19
|
|
|
20
20
|
from flwr.common import Context
|
|
21
21
|
|
|
22
|
-
from .
|
|
22
|
+
from .grid import Grid
|
|
23
23
|
from .serverapp_components import ServerAppComponents
|
|
24
24
|
|
|
25
|
-
ServerAppCallable = Callable[[
|
|
26
|
-
Workflow = Callable[[
|
|
25
|
+
ServerAppCallable = Callable[[Grid, Context], None]
|
|
26
|
+
Workflow = Callable[[Grid, Context], None]
|
|
27
27
|
ServerFn = Callable[[Context], ServerAppComponents]
|
flwr/server/utils/__init__.py
CHANGED
flwr/server/utils/tensorboard.py
CHANGED
flwr/server/utils/validator.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -68,8 +68,8 @@ def validate_message(message: Message, is_reply_message: bool) -> list[str]:
|
|
|
68
68
|
|
|
69
69
|
# Link respose to original message
|
|
70
70
|
if not is_reply_message:
|
|
71
|
-
if metadata.
|
|
72
|
-
validation_errors.append("`metadata.
|
|
71
|
+
if metadata.reply_to_message_id != "":
|
|
72
|
+
validation_errors.append("`metadata.reply_to_message_id` MUST not be set.")
|
|
73
73
|
if metadata.src_node_id != SUPERLINK_NODE_ID:
|
|
74
74
|
validation_errors.append(
|
|
75
75
|
f"`metadata.src_node_id` is not {SUPERLINK_NODE_ID} (SuperLink node ID)"
|
|
@@ -79,8 +79,8 @@ def validate_message(message: Message, is_reply_message: bool) -> list[str]:
|
|
|
79
79
|
f"`metadata.dst_node_id` is {SUPERLINK_NODE_ID} (SuperLink node ID)"
|
|
80
80
|
)
|
|
81
81
|
else:
|
|
82
|
-
if metadata.
|
|
83
|
-
validation_errors.append("`metadata.
|
|
82
|
+
if metadata.reply_to_message_id == "":
|
|
83
|
+
validation_errors.append("`metadata.reply_to_message_id` MUST be set.")
|
|
84
84
|
if metadata.src_node_id == SUPERLINK_NODE_ID:
|
|
85
85
|
validation_errors.append(
|
|
86
86
|
f"`metadata.src_node_id` is {SUPERLINK_NODE_ID} (SuperLink node ID)"
|
flwr/server/workflow/__init__.py
CHANGED
flwr/server/workflow/constant.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -20,15 +20,16 @@ import timeit
|
|
|
20
20
|
from logging import INFO, WARN
|
|
21
21
|
from typing import Optional, Union, cast
|
|
22
22
|
|
|
23
|
-
import flwr.common.
|
|
23
|
+
import flwr.common.recorddict_compat as compat
|
|
24
24
|
from flwr.common import (
|
|
25
|
+
ArrayRecord,
|
|
25
26
|
Code,
|
|
26
|
-
|
|
27
|
+
ConfigRecord,
|
|
27
28
|
Context,
|
|
28
29
|
EvaluateRes,
|
|
29
30
|
FitRes,
|
|
30
31
|
GetParametersIns,
|
|
31
|
-
|
|
32
|
+
Message,
|
|
32
33
|
log,
|
|
33
34
|
)
|
|
34
35
|
from flwr.common.constant import MessageType, MessageTypeLegacy
|
|
@@ -36,7 +37,7 @@ from flwr.common.constant import MessageType, MessageTypeLegacy
|
|
|
36
37
|
from ..client_proxy import ClientProxy
|
|
37
38
|
from ..compat.app_utils import start_update_client_manager_thread
|
|
38
39
|
from ..compat.legacy_context import LegacyContext
|
|
39
|
-
from ..
|
|
40
|
+
from ..grid import Grid
|
|
40
41
|
from ..typing import Workflow
|
|
41
42
|
from .constant import MAIN_CONFIGS_RECORD, MAIN_PARAMS_RECORD, Key
|
|
42
43
|
|
|
@@ -56,7 +57,7 @@ class DefaultWorkflow:
|
|
|
56
57
|
self.fit_workflow: Workflow = fit_workflow
|
|
57
58
|
self.evaluate_workflow: Workflow = evaluate_workflow
|
|
58
59
|
|
|
59
|
-
def __call__(self,
|
|
60
|
+
def __call__(self, grid: Grid, context: Context) -> None:
|
|
60
61
|
"""Execute the workflow."""
|
|
61
62
|
if not isinstance(context, LegacyContext):
|
|
62
63
|
raise TypeError(
|
|
@@ -65,7 +66,7 @@ class DefaultWorkflow:
|
|
|
65
66
|
|
|
66
67
|
# Start the thread updating nodes
|
|
67
68
|
thread, f_stop, c_done = start_update_client_manager_thread(
|
|
68
|
-
|
|
69
|
+
grid, context.client_manager
|
|
69
70
|
)
|
|
70
71
|
|
|
71
72
|
# Wait until the node registration done
|
|
@@ -73,13 +74,13 @@ class DefaultWorkflow:
|
|
|
73
74
|
|
|
74
75
|
# Initialize parameters
|
|
75
76
|
log(INFO, "[INIT]")
|
|
76
|
-
default_init_params_workflow(
|
|
77
|
+
default_init_params_workflow(grid, context)
|
|
77
78
|
|
|
78
79
|
# Run federated learning for num_rounds
|
|
79
80
|
start_time = timeit.default_timer()
|
|
80
|
-
cfg =
|
|
81
|
+
cfg = ConfigRecord()
|
|
81
82
|
cfg[Key.START_TIME] = start_time
|
|
82
|
-
context.state.
|
|
83
|
+
context.state.config_records[MAIN_CONFIGS_RECORD] = cfg
|
|
83
84
|
|
|
84
85
|
for current_round in range(1, context.config.num_rounds + 1):
|
|
85
86
|
log(INFO, "")
|
|
@@ -87,13 +88,13 @@ class DefaultWorkflow:
|
|
|
87
88
|
cfg[Key.CURRENT_ROUND] = current_round
|
|
88
89
|
|
|
89
90
|
# Fit round
|
|
90
|
-
self.fit_workflow(
|
|
91
|
+
self.fit_workflow(grid, context)
|
|
91
92
|
|
|
92
93
|
# Centralized evaluation
|
|
93
|
-
default_centralized_evaluation_workflow(
|
|
94
|
+
default_centralized_evaluation_workflow(grid, context)
|
|
94
95
|
|
|
95
96
|
# Evaluate round
|
|
96
|
-
self.evaluate_workflow(
|
|
97
|
+
self.evaluate_workflow(grid, context)
|
|
97
98
|
|
|
98
99
|
# Bookkeeping and log results
|
|
99
100
|
end_time = timeit.default_timer()
|
|
@@ -119,7 +120,7 @@ class DefaultWorkflow:
|
|
|
119
120
|
thread.join()
|
|
120
121
|
|
|
121
122
|
|
|
122
|
-
def default_init_params_workflow(
|
|
123
|
+
def default_init_params_workflow(grid: Grid, context: Context) -> None:
|
|
123
124
|
"""Execute the default workflow for parameters initialization."""
|
|
124
125
|
if not isinstance(context, LegacyContext):
|
|
125
126
|
raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
|
|
@@ -129,21 +130,19 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
|
|
|
129
130
|
)
|
|
130
131
|
if parameters is not None:
|
|
131
132
|
log(INFO, "Using initial global parameters provided by strategy")
|
|
132
|
-
|
|
133
|
-
parameters, keep_input=True
|
|
134
|
-
)
|
|
133
|
+
arr_record = compat.parameters_to_arrayrecord(parameters, keep_input=True)
|
|
135
134
|
else:
|
|
136
135
|
# Get initial parameters from one of the clients
|
|
137
136
|
log(INFO, "Requesting initial parameters from one random client")
|
|
138
137
|
random_client = context.client_manager.sample(1)[0]
|
|
139
138
|
# Send GetParametersIns and get the response
|
|
140
|
-
content = compat.
|
|
141
|
-
messages =
|
|
139
|
+
content = compat.getparametersins_to_recorddict(GetParametersIns({}))
|
|
140
|
+
messages = grid.send_and_receive(
|
|
142
141
|
[
|
|
143
|
-
|
|
142
|
+
Message(
|
|
144
143
|
content=content,
|
|
145
|
-
message_type=MessageTypeLegacy.GET_PARAMETERS,
|
|
146
144
|
dst_node_id=random_client.node_id,
|
|
145
|
+
message_type=MessageTypeLegacy.GET_PARAMETERS,
|
|
147
146
|
group_id="0",
|
|
148
147
|
)
|
|
149
148
|
]
|
|
@@ -152,26 +151,26 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
|
|
|
152
151
|
|
|
153
152
|
if (
|
|
154
153
|
msg.has_content()
|
|
155
|
-
and compat.
|
|
154
|
+
and compat._extract_status_from_recorddict( # pylint: disable=W0212
|
|
156
155
|
"getparametersres", msg.content
|
|
157
156
|
).code
|
|
158
157
|
== Code.OK
|
|
159
158
|
):
|
|
160
159
|
log(INFO, "Received initial parameters from one random client")
|
|
161
|
-
|
|
160
|
+
arr_record = next(iter(msg.content.array_records.values()))
|
|
162
161
|
else:
|
|
163
162
|
log(
|
|
164
163
|
WARN,
|
|
165
164
|
"Failed to receive initial parameters from the client."
|
|
166
165
|
" Empty initial parameters will be used.",
|
|
167
166
|
)
|
|
168
|
-
|
|
167
|
+
arr_record = ArrayRecord()
|
|
169
168
|
|
|
170
|
-
context.state.
|
|
169
|
+
context.state.array_records[MAIN_PARAMS_RECORD] = arr_record
|
|
171
170
|
|
|
172
171
|
# Evaluate initial parameters
|
|
173
172
|
log(INFO, "Starting evaluation of initial global parameters")
|
|
174
|
-
parameters = compat.
|
|
173
|
+
parameters = compat.arrayrecord_to_parameters(arr_record, keep_input=True)
|
|
175
174
|
res = context.strategy.evaluate(0, parameters=parameters)
|
|
176
175
|
if res is not None:
|
|
177
176
|
log(
|
|
@@ -186,19 +185,19 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
|
|
|
186
185
|
log(INFO, "Evaluation returned no results (`None`)")
|
|
187
186
|
|
|
188
187
|
|
|
189
|
-
def default_centralized_evaluation_workflow(_:
|
|
188
|
+
def default_centralized_evaluation_workflow(_: Grid, context: Context) -> None:
|
|
190
189
|
"""Execute the default workflow for centralized evaluation."""
|
|
191
190
|
if not isinstance(context, LegacyContext):
|
|
192
191
|
raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
|
|
193
192
|
|
|
194
193
|
# Retrieve current_round and start_time from the context
|
|
195
|
-
cfg = context.state.
|
|
194
|
+
cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
|
|
196
195
|
current_round = cast(int, cfg[Key.CURRENT_ROUND])
|
|
197
196
|
start_time = cast(float, cfg[Key.START_TIME])
|
|
198
197
|
|
|
199
198
|
# Centralized evaluation
|
|
200
|
-
parameters = compat.
|
|
201
|
-
record=context.state.
|
|
199
|
+
parameters = compat.arrayrecord_to_parameters(
|
|
200
|
+
record=context.state.array_records[MAIN_PARAMS_RECORD],
|
|
202
201
|
keep_input=True,
|
|
203
202
|
)
|
|
204
203
|
res_cen = context.strategy.evaluate(current_round, parameters=parameters)
|
|
@@ -218,20 +217,16 @@ def default_centralized_evaluation_workflow(_: Driver, context: Context) -> None
|
|
|
218
217
|
)
|
|
219
218
|
|
|
220
219
|
|
|
221
|
-
def default_fit_workflow( # pylint: disable=R0914
|
|
222
|
-
driver: Driver, context: Context
|
|
223
|
-
) -> None:
|
|
220
|
+
def default_fit_workflow(grid: Grid, context: Context) -> None: # pylint: disable=R0914
|
|
224
221
|
"""Execute the default workflow for a single fit round."""
|
|
225
222
|
if not isinstance(context, LegacyContext):
|
|
226
223
|
raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
|
|
227
224
|
|
|
228
225
|
# Get current_round and parameters
|
|
229
|
-
cfg = context.state.
|
|
226
|
+
cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
|
|
230
227
|
current_round = cast(int, cfg[Key.CURRENT_ROUND])
|
|
231
|
-
|
|
232
|
-
parameters = compat.
|
|
233
|
-
parametersrecord, keep_input=True
|
|
234
|
-
)
|
|
228
|
+
arr_record = context.state.array_records[MAIN_PARAMS_RECORD]
|
|
229
|
+
parameters = compat.arrayrecord_to_parameters(arr_record, keep_input=True)
|
|
235
230
|
|
|
236
231
|
# Get clients and their respective instructions from strategy
|
|
237
232
|
client_instructions = context.strategy.configure_fit(
|
|
@@ -255,10 +250,10 @@ def default_fit_workflow( # pylint: disable=R0914
|
|
|
255
250
|
|
|
256
251
|
# Build out messages
|
|
257
252
|
out_messages = [
|
|
258
|
-
|
|
259
|
-
content=compat.
|
|
260
|
-
message_type=MessageType.TRAIN,
|
|
253
|
+
Message(
|
|
254
|
+
content=compat.fitins_to_recorddict(fitins, True),
|
|
261
255
|
dst_node_id=proxy.node_id,
|
|
256
|
+
message_type=MessageType.TRAIN,
|
|
262
257
|
group_id=str(current_round),
|
|
263
258
|
)
|
|
264
259
|
for proxy, fitins in client_instructions
|
|
@@ -266,7 +261,7 @@ def default_fit_workflow( # pylint: disable=R0914
|
|
|
266
261
|
|
|
267
262
|
# Send instructions to clients and
|
|
268
263
|
# collect `fit` results from all clients participating in this round
|
|
269
|
-
messages = list(
|
|
264
|
+
messages = list(grid.send_and_receive(out_messages))
|
|
270
265
|
del out_messages
|
|
271
266
|
num_failures = len([msg for msg in messages if msg.has_error()])
|
|
272
267
|
|
|
@@ -284,7 +279,7 @@ def default_fit_workflow( # pylint: disable=R0914
|
|
|
284
279
|
for msg in messages:
|
|
285
280
|
if msg.has_content():
|
|
286
281
|
proxy = node_id_to_proxy[msg.metadata.src_node_id]
|
|
287
|
-
fitres = compat.
|
|
282
|
+
fitres = compat.recorddict_to_fitres(msg.content, False)
|
|
288
283
|
if fitres.status.code == Code.OK:
|
|
289
284
|
results.append((proxy, fitres))
|
|
290
285
|
else:
|
|
@@ -297,28 +292,24 @@ def default_fit_workflow( # pylint: disable=R0914
|
|
|
297
292
|
|
|
298
293
|
# Update the parameters and write history
|
|
299
294
|
if parameters_aggregated:
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
)
|
|
303
|
-
context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
|
|
295
|
+
arr_record = compat.parameters_to_arrayrecord(parameters_aggregated, True)
|
|
296
|
+
context.state.array_records[MAIN_PARAMS_RECORD] = arr_record
|
|
304
297
|
context.history.add_metrics_distributed_fit(
|
|
305
298
|
server_round=current_round, metrics=metrics_aggregated
|
|
306
299
|
)
|
|
307
300
|
|
|
308
301
|
|
|
309
302
|
# pylint: disable-next=R0914
|
|
310
|
-
def default_evaluate_workflow(
|
|
303
|
+
def default_evaluate_workflow(grid: Grid, context: Context) -> None:
|
|
311
304
|
"""Execute the default workflow for a single evaluate round."""
|
|
312
305
|
if not isinstance(context, LegacyContext):
|
|
313
306
|
raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
|
|
314
307
|
|
|
315
308
|
# Get current_round and parameters
|
|
316
|
-
cfg = context.state.
|
|
309
|
+
cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
|
|
317
310
|
current_round = cast(int, cfg[Key.CURRENT_ROUND])
|
|
318
|
-
|
|
319
|
-
parameters = compat.
|
|
320
|
-
parametersrecord, keep_input=True
|
|
321
|
-
)
|
|
311
|
+
arr_record = context.state.array_records[MAIN_PARAMS_RECORD]
|
|
312
|
+
parameters = compat.arrayrecord_to_parameters(arr_record, keep_input=True)
|
|
322
313
|
|
|
323
314
|
# Get clients and their respective instructions from strategy
|
|
324
315
|
client_instructions = context.strategy.configure_evaluate(
|
|
@@ -341,10 +332,10 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
|
|
|
341
332
|
|
|
342
333
|
# Build out messages
|
|
343
334
|
out_messages = [
|
|
344
|
-
|
|
345
|
-
content=compat.
|
|
346
|
-
message_type=MessageType.EVALUATE,
|
|
335
|
+
Message(
|
|
336
|
+
content=compat.evaluateins_to_recorddict(evalins, True),
|
|
347
337
|
dst_node_id=proxy.node_id,
|
|
338
|
+
message_type=MessageType.EVALUATE,
|
|
348
339
|
group_id=str(current_round),
|
|
349
340
|
)
|
|
350
341
|
for proxy, evalins in client_instructions
|
|
@@ -352,7 +343,7 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
|
|
|
352
343
|
|
|
353
344
|
# Send instructions to clients and
|
|
354
345
|
# collect `evaluate` results from all clients participating in this round
|
|
355
|
-
messages = list(
|
|
346
|
+
messages = list(grid.send_and_receive(out_messages))
|
|
356
347
|
del out_messages
|
|
357
348
|
num_failures = len([msg for msg in messages if msg.has_error()])
|
|
358
349
|
|
|
@@ -370,7 +361,7 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
|
|
|
370
361
|
for msg in messages:
|
|
371
362
|
if msg.has_content():
|
|
372
363
|
proxy = node_id_to_proxy[msg.metadata.src_node_id]
|
|
373
|
-
evalres = compat.
|
|
364
|
+
evalres = compat.recorddict_to_evaluateres(msg.content)
|
|
374
365
|
if evalres.status.code == Code.OK:
|
|
375
366
|
results.append((proxy, evalres))
|
|
376
367
|
else:
|