flwr 1.15.2__py3-none-any.whl → 1.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/build.py +2 -0
- flwr/cli/log.py +20 -21
- flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/run.py +5 -9
- flwr/client/app.py +6 -4
- flwr/client/client_app.py +260 -86
- flwr/client/clientapp/app.py +6 -2
- flwr/client/grpc_client/connection.py +24 -21
- flwr/client/message_handler/message_handler.py +28 -28
- flwr/client/mod/__init__.py +2 -2
- flwr/client/mod/centraldp_mods.py +7 -7
- flwr/client/mod/comms_mods.py +16 -22
- flwr/client/mod/localdp_mod.py +4 -4
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +31 -31
- flwr/client/rest_client/connection.py +4 -6
- flwr/client/run_info_store.py +2 -2
- flwr/client/supernode/__init__.py +0 -2
- flwr/client/supernode/app.py +1 -11
- flwr/common/__init__.py +12 -4
- flwr/common/address.py +35 -0
- flwr/common/args.py +8 -2
- flwr/common/auth_plugin/auth_plugin.py +2 -1
- flwr/common/config.py +4 -4
- flwr/common/constant.py +16 -0
- flwr/common/context.py +4 -4
- flwr/common/event_log_plugin/__init__.py +22 -0
- flwr/common/event_log_plugin/event_log_plugin.py +60 -0
- flwr/common/grpc.py +1 -1
- flwr/common/logger.py +2 -2
- flwr/common/message.py +338 -102
- flwr/common/object_ref.py +0 -10
- flwr/common/record/__init__.py +8 -4
- flwr/common/record/arrayrecord.py +626 -0
- flwr/common/record/{configsrecord.py → configrecord.py} +75 -29
- flwr/common/record/conversion_utils.py +9 -18
- flwr/common/record/{metricsrecord.py → metricrecord.py} +78 -32
- flwr/common/record/recorddict.py +288 -0
- flwr/common/recorddict_compat.py +410 -0
- flwr/common/secure_aggregation/quantization.py +5 -1
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/serde.py +67 -190
- flwr/common/telemetry.py +0 -10
- flwr/common/typing.py +44 -8
- flwr/proto/exec_pb2.py +3 -3
- flwr/proto/exec_pb2.pyi +3 -3
- flwr/proto/message_pb2.py +12 -12
- flwr/proto/message_pb2.pyi +9 -9
- flwr/proto/recorddict_pb2.py +70 -0
- flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +35 -35
- flwr/proto/run_pb2.py +31 -31
- flwr/proto/run_pb2.pyi +3 -3
- flwr/server/__init__.py +3 -1
- flwr/server/app.py +74 -3
- flwr/server/compat/__init__.py +2 -2
- flwr/server/compat/app.py +15 -12
- flwr/server/compat/app_utils.py +26 -18
- flwr/server/compat/{driver_client_proxy.py → grid_client_proxy.py} +41 -41
- flwr/server/fleet_event_log_interceptor.py +94 -0
- flwr/server/{driver → grid}/__init__.py +8 -7
- flwr/server/{driver/driver.py → grid/grid.py} +48 -19
- flwr/server/{driver/grpc_driver.py → grid/grpc_grid.py} +88 -56
- flwr/server/{driver/inmemory_driver.py → grid/inmemory_grid.py} +41 -54
- flwr/server/run_serverapp.py +6 -17
- flwr/server/server_app.py +126 -33
- flwr/server/serverapp/app.py +10 -10
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +2 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +8 -12
- flwr/server/superlink/fleet/vce/backend/backend.py +3 -3
- flwr/server/superlink/fleet/vce/backend/raybackend.py +2 -2
- flwr/server/superlink/fleet/vce/vce_api.py +33 -38
- flwr/server/superlink/linkstate/in_memory_linkstate.py +171 -132
- flwr/server/superlink/linkstate/linkstate.py +51 -64
- flwr/server/superlink/linkstate/sqlite_linkstate.py +253 -285
- flwr/server/superlink/linkstate/utils.py +171 -133
- flwr/server/superlink/{driver → serverappio}/__init__.py +1 -1
- flwr/server/superlink/{driver → serverappio}/serverappio_grpc.py +1 -1
- flwr/server/superlink/{driver → serverappio}/serverappio_servicer.py +27 -29
- flwr/server/superlink/simulation/simulationio_servicer.py +2 -2
- flwr/server/typing.py +3 -3
- flwr/server/utils/__init__.py +2 -2
- flwr/server/utils/validator.py +53 -68
- flwr/server/workflow/default_workflows.py +52 -58
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +48 -50
- flwr/simulation/app.py +2 -2
- flwr/simulation/ray_transport/ray_actor.py +4 -2
- flwr/simulation/ray_transport/ray_client_proxy.py +34 -32
- flwr/simulation/run_simulation.py +15 -15
- flwr/superexec/app.py +0 -14
- flwr/superexec/deployment.py +4 -4
- flwr/superexec/exec_event_log_interceptor.py +135 -0
- flwr/superexec/exec_grpc.py +10 -4
- flwr/superexec/exec_servicer.py +6 -6
- flwr/superexec/exec_user_auth_interceptor.py +22 -4
- flwr/superexec/executor.py +3 -3
- flwr/superexec/simulation.py +3 -3
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/METADATA +5 -5
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/RECORD +111 -112
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/entry_points.txt +0 -3
- flwr/client/message_handler/task_handler.py +0 -37
- flwr/common/record/parametersrecord.py +0 -204
- flwr/common/record/recordset.py +0 -202
- flwr/common/recordset_compat.py +0 -418
- flwr/proto/recordset_pb2.py +0 -70
- flwr/proto/task_pb2.py +0 -33
- flwr/proto/task_pb2.pyi +0 -100
- flwr/proto/task_pb2_grpc.py +0 -4
- flwr/proto/task_pb2_grpc.pyi +0 -4
- /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
- /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/LICENSE +0 -0
- {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/WHEEL +0 -0
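The module renames above (configsrecord.py → configrecord.py, metricsrecord.py → metricrecord.py, recordset.py → recorddict.py, plus the new arrayrecord.py and recorddict_compat.py) suggest the record classes themselves were renamed between 1.15.2 and 1.17.0. A minimal before/after sketch of what user code might look like, assuming the 1.17.x class names mirror the new module names (verify against the flwr 1.17.0 documentation; the values used here are made up):

    # 1.15.x:  from flwr.common import ConfigsRecord, MetricsRecord, RecordSet
    # 1.17.x (assumed from the renamed modules above):
    from flwr.common import ConfigRecord, MetricRecord, RecordDict

    config = ConfigRecord({"lr": 0.01, "batch-size": 32})  # run configuration values
    metrics = MetricRecord({"accuracy": 0.91})             # numeric metrics
    records = RecordDict({"my-config": config, "my-metrics": metrics})

    print(records["my-metrics"]["accuracy"])  # records behave like nested dicts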
flwr/server/utils/validator.py
CHANGED
@@ -16,93 +16,78 @@


 import time
-from typing import Union

+from flwr.common import Message
 from flwr.common.constant import SUPERLINK_NODE_ID
-from flwr.proto.task_pb2 import TaskIns, TaskRes  # pylint: disable=E0611


-# pylint: disable-next=too-many-branches
-def
-    """Validate a
+# pylint: disable-next=too-many-branches
+def validate_message(message: Message, is_reply_message: bool) -> list[str]:
+    """Validate a Message."""
     validation_errors = []
+    metadata = message.metadata

-    if
-        validation_errors.append("non-empty `
-
-    if not tasks_ins_res.HasField("task"):
-        validation_errors.append("`task` does not set field `task`")
+    if metadata.message_id != "":
+        validation_errors.append("non-empty `metadata.message_id`")

     # Created/delivered/TTL/Pushed
     if (
-
-    ):  # unix timestamp of
+        metadata.created_at < 1740700800.0
+    ):  # unix timestamp of 28 February 2025 00h:00m:00s UTC
         validation_errors.append(
-            "`created_at` must be a float that records the unix timestamp "
+            "`metadata.created_at` must be a float that records the unix timestamp "
             "in seconds when the message was created."
         )
-    if
-        validation_errors.append("`delivered_at` must be an empty str")
-    if
-        validation_errors.append("`ttl` must be higher than zero")
+    if metadata.delivered_at != "":
+        validation_errors.append("`metadata.delivered_at` must be an empty str")
+    if metadata.ttl <= 0:
+        validation_errors.append("`metadata.ttl` must be higher than zero")

     # Verify TTL and created_at time
     current_time = time.time()
-    if
-        validation_errors.append("
-
-    # TaskIns specific
-    if isinstance(tasks_ins_res, TaskIns):
-        # Task producer
-        if not tasks_ins_res.task.HasField("producer"):
-            validation_errors.append("`producer` does not set field `producer`")
-        if tasks_ins_res.task.producer.node_id != SUPERLINK_NODE_ID:
-            validation_errors.append(f"`producer.node_id` is not {SUPERLINK_NODE_ID}")
-
-        # Task consumer
-        if not tasks_ins_res.task.HasField("consumer"):
-            validation_errors.append("`consumer` does not set field `consumer`")
-        if tasks_ins_res.task.consumer.node_id == SUPERLINK_NODE_ID:
-            validation_errors.append("consumer MUST provide a valid `node_id`")
+    if metadata.created_at + metadata.ttl <= current_time:
+        validation_errors.append("Message TTL has expired")

-
-
-
-        if not (
-            tasks_ins_res.task.HasField("recordset")
-            ^ tasks_ins_res.task.HasField("error")
-        ):
-            validation_errors.append("Either `recordset` or `error` MUST be set")
+    # Source node is set and is not zero
+    if not metadata.src_node_id:
+        validation_errors.append("`metadata.src_node_id` is not set.")

-
-
-
+    # Destination node is set and is not zero
+    if not metadata.dst_node_id:
+        validation_errors.append("`metadata.dst_node_id` is not set.")

-        #
-        if
-
-        if not tasks_ins_res.task.HasField("producer"):
-            validation_errors.append("`producer` does not set field `producer`")
-        if tasks_ins_res.task.producer.node_id == SUPERLINK_NODE_ID:
-            validation_errors.append("producer MUST provide a valid `node_id`")
+    # Message type
+    if metadata.message_type == "":
+        validation_errors.append("`metadata.message_type` MUST be set")

-
-
-
-
-
-
-        # Content check
-        if tasks_ins_res.task.task_type == "":
-            validation_errors.append("`task_type` MUST be set")
-        if not (
-            tasks_ins_res.task.HasField("recordset")
-            ^ tasks_ins_res.task.HasField("error")
-        ):
-            validation_errors.append("Either `recordset` or `error` MUST be set")
+    # Content
+    if not message.has_content() != message.has_error():
+        validation_errors.append(
+            "Either message `content` or `error` MUST be set (but not both)"
+        )

-
-
-
+    # Link respose to original message
+    if not is_reply_message:
+        if metadata.reply_to_message_id != "":
+            validation_errors.append("`metadata.reply_to_message_id` MUST not be set.")
+        if metadata.src_node_id != SUPERLINK_NODE_ID:
+            validation_errors.append(
+                f"`metadata.src_node_id` is not {SUPERLINK_NODE_ID} (SuperLink node ID)"
+            )
+        if metadata.dst_node_id == SUPERLINK_NODE_ID:
+            validation_errors.append(
+                f"`metadata.dst_node_id` is {SUPERLINK_NODE_ID} (SuperLink node ID)"
+            )
+    else:
+        if metadata.reply_to_message_id == "":
+            validation_errors.append("`metadata.reply_to_message_id` MUST be set.")
+        if metadata.src_node_id == SUPERLINK_NODE_ID:
+            validation_errors.append(
+                f"`metadata.src_node_id` is {SUPERLINK_NODE_ID} (SuperLink node ID)"
+            )
+        if metadata.dst_node_id != SUPERLINK_NODE_ID:
+            validation_errors.append(
+                f"`metadata.dst_node_id` is not {SUPERLINK_NODE_ID} (SuperLink node ID)"
+            )

     return validation_errors
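Two details in the rewritten validator are easy to misread: the constant 1740700800.0 is the Unix timestamp of 28 February 2025 00:00:00 UTC (any earlier `created_at` is rejected), and `message.has_content() != message.has_error()` acts as a boolean XOR, so exactly one of content and error must be set. A small standalone illustration of both checks (plain Python, not flwr API):

    from datetime import datetime, timezone

    # The `created_at` lower bound used by validate_message above.
    assert datetime(2025, 2, 28, tzinfo=timezone.utc).timestamp() == 1740700800.0

    def exactly_one(has_content: bool, has_error: bool) -> bool:
        # `!=` on two booleans is XOR: True only when exactly one flag is set,
        # mirroring `message.has_content() != message.has_error()`.
        return has_content != has_error

    assert exactly_one(True, False) and exactly_one(False, True)
    assert not exactly_one(True, True) and not exactly_one(False, False)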
flwr/server/workflow/default_workflows.py
CHANGED
@@ -20,15 +20,16 @@ import timeit
 from logging import INFO, WARN
 from typing import Optional, Union, cast

-import flwr.common.
+import flwr.common.recorddict_compat as compat
 from flwr.common import (
+    ArrayRecord,
     Code,
-
+    ConfigRecord,
     Context,
     EvaluateRes,
     FitRes,
     GetParametersIns,
-
+    Message,
     log,
 )
 from flwr.common.constant import MessageType, MessageTypeLegacy
@@ -36,7 +37,7 @@ from flwr.common.constant import MessageType, MessageTypeLegacy
 from ..client_proxy import ClientProxy
 from ..compat.app_utils import start_update_client_manager_thread
 from ..compat.legacy_context import LegacyContext
-from ..
+from ..grid import Grid
 from ..typing import Workflow
 from .constant import MAIN_CONFIGS_RECORD, MAIN_PARAMS_RECORD, Key

@@ -56,7 +57,7 @@ class DefaultWorkflow:
         self.fit_workflow: Workflow = fit_workflow
         self.evaluate_workflow: Workflow = evaluate_workflow

-    def __call__(self,
+    def __call__(self, grid: Grid, context: Context) -> None:
         """Execute the workflow."""
         if not isinstance(context, LegacyContext):
             raise TypeError(
@@ -64,19 +65,22 @@
             )

         # Start the thread updating nodes
-        thread, f_stop = start_update_client_manager_thread(
-
+        thread, f_stop, c_done = start_update_client_manager_thread(
+            grid, context.client_manager
         )

+        # Wait until the node registration done
+        c_done.wait()
+
         # Initialize parameters
         log(INFO, "[INIT]")
-        default_init_params_workflow(
+        default_init_params_workflow(grid, context)

         # Run federated learning for num_rounds
         start_time = timeit.default_timer()
-        cfg =
+        cfg = ConfigRecord()
         cfg[Key.START_TIME] = start_time
-        context.state.
+        context.state.config_records[MAIN_CONFIGS_RECORD] = cfg

         for current_round in range(1, context.config.num_rounds + 1):
             log(INFO, "")
@@ -84,13 +88,13 @@
             cfg[Key.CURRENT_ROUND] = current_round

             # Fit round
-            self.fit_workflow(
+            self.fit_workflow(grid, context)

             # Centralized evaluation
-            default_centralized_evaluation_workflow(
+            default_centralized_evaluation_workflow(grid, context)

             # Evaluate round
-            self.evaluate_workflow(
+            self.evaluate_workflow(grid, context)

             # Bookkeeping and log results
             end_time = timeit.default_timer()
@@ -116,7 +120,7 @@
         thread.join()


-def default_init_params_workflow(
+def default_init_params_workflow(grid: Grid, context: Context) -> None:
     """Execute the default workflow for parameters initialization."""
     if not isinstance(context, LegacyContext):
         raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
@@ -126,21 +130,19 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
     )
     if parameters is not None:
         log(INFO, "Using initial global parameters provided by strategy")
-
-            parameters, keep_input=True
-        )
+        arr_record = compat.parameters_to_arrayrecord(parameters, keep_input=True)
     else:
         # Get initial parameters from one of the clients
         log(INFO, "Requesting initial parameters from one random client")
         random_client = context.client_manager.sample(1)[0]
         # Send GetParametersIns and get the response
-        content = compat.
-        messages =
+        content = compat.getparametersins_to_recorddict(GetParametersIns({}))
+        messages = grid.send_and_receive(
             [
-
+                Message(
                     content=content,
-                    message_type=MessageTypeLegacy.GET_PARAMETERS,
                     dst_node_id=random_client.node_id,
+                    message_type=MessageTypeLegacy.GET_PARAMETERS,
                     group_id="0",
                 )
             ]
@@ -149,26 +151,26 @@

         if (
             msg.has_content()
-            and compat.
+            and compat._extract_status_from_recorddict(  # pylint: disable=W0212
                 "getparametersres", msg.content
             ).code
             == Code.OK
         ):
             log(INFO, "Received initial parameters from one random client")
-
+            arr_record = next(iter(msg.content.array_records.values()))
         else:
             log(
                 WARN,
                 "Failed to receive initial parameters from the client."
                 " Empty initial parameters will be used.",
             )
-
+            arr_record = ArrayRecord()

-    context.state.
+    context.state.array_records[MAIN_PARAMS_RECORD] = arr_record

     # Evaluate initial parameters
     log(INFO, "Starting evaluation of initial global parameters")
-    parameters = compat.
+    parameters = compat.arrayrecord_to_parameters(arr_record, keep_input=True)
     res = context.strategy.evaluate(0, parameters=parameters)
     if res is not None:
         log(
@@ -183,19 +185,19 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
         log(INFO, "Evaluation returned no results (`None`)")


-def default_centralized_evaluation_workflow(_:
+def default_centralized_evaluation_workflow(_: Grid, context: Context) -> None:
     """Execute the default workflow for centralized evaluation."""
     if not isinstance(context, LegacyContext):
         raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")

     # Retrieve current_round and start_time from the context
-    cfg = context.state.
+    cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
     current_round = cast(int, cfg[Key.CURRENT_ROUND])
     start_time = cast(float, cfg[Key.START_TIME])

     # Centralized evaluation
-    parameters = compat.
-        record=context.state.
+    parameters = compat.arrayrecord_to_parameters(
+        record=context.state.array_records[MAIN_PARAMS_RECORD],
         keep_input=True,
     )
     res_cen = context.strategy.evaluate(current_round, parameters=parameters)
@@ -215,20 +217,16 @@ def default_centralized_evaluation_workflow(_: Driver, context: Context) -> None
     )


-def default_fit_workflow(  # pylint: disable=R0914
-    driver: Driver, context: Context
-) -> None:
+def default_fit_workflow(grid: Grid, context: Context) -> None:  # pylint: disable=R0914
     """Execute the default workflow for a single fit round."""
     if not isinstance(context, LegacyContext):
         raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")

     # Get current_round and parameters
-    cfg = context.state.
+    cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
     current_round = cast(int, cfg[Key.CURRENT_ROUND])
-
-    parameters = compat.
-        parametersrecord, keep_input=True
-    )
+    arr_record = context.state.array_records[MAIN_PARAMS_RECORD]
+    parameters = compat.arrayrecord_to_parameters(arr_record, keep_input=True)

     # Get clients and their respective instructions from strategy
     client_instructions = context.strategy.configure_fit(
@@ -252,10 +250,10 @@ def default_fit_workflow(  # pylint: disable=R0914

     # Build out messages
     out_messages = [
-
-            content=compat.
-            message_type=MessageType.TRAIN,
+        Message(
+            content=compat.fitins_to_recorddict(fitins, True),
             dst_node_id=proxy.node_id,
+            message_type=MessageType.TRAIN,
             group_id=str(current_round),
         )
         for proxy, fitins in client_instructions
@@ -263,7 +261,7 @@ def default_fit_workflow(  # pylint: disable=R0914

     # Send instructions to clients and
     # collect `fit` results from all clients participating in this round
-    messages = list(
+    messages = list(grid.send_and_receive(out_messages))
     del out_messages
     num_failures = len([msg for msg in messages if msg.has_error()])

@@ -281,7 +279,7 @@ def default_fit_workflow(  # pylint: disable=R0914
     for msg in messages:
         if msg.has_content():
             proxy = node_id_to_proxy[msg.metadata.src_node_id]
-            fitres = compat.
+            fitres = compat.recorddict_to_fitres(msg.content, False)
             if fitres.status.code == Code.OK:
                 results.append((proxy, fitres))
             else:
@@ -294,28 +292,24 @@ def default_fit_workflow(  # pylint: disable=R0914

     # Update the parameters and write history
     if parameters_aggregated:
-
-
-        )
-        context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
+        arr_record = compat.parameters_to_arrayrecord(parameters_aggregated, True)
+        context.state.array_records[MAIN_PARAMS_RECORD] = arr_record
         context.history.add_metrics_distributed_fit(
             server_round=current_round, metrics=metrics_aggregated
         )


 # pylint: disable-next=R0914
-def default_evaluate_workflow(
+def default_evaluate_workflow(grid: Grid, context: Context) -> None:
     """Execute the default workflow for a single evaluate round."""
     if not isinstance(context, LegacyContext):
         raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")

     # Get current_round and parameters
-    cfg = context.state.
+    cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
     current_round = cast(int, cfg[Key.CURRENT_ROUND])
-
-    parameters = compat.
-        parametersrecord, keep_input=True
-    )
+    arr_record = context.state.array_records[MAIN_PARAMS_RECORD]
+    parameters = compat.arrayrecord_to_parameters(arr_record, keep_input=True)

     # Get clients and their respective instructions from strategy
     client_instructions = context.strategy.configure_evaluate(
@@ -338,10 +332,10 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:

     # Build out messages
     out_messages = [
-
-            content=compat.
-            message_type=MessageType.EVALUATE,
+        Message(
+            content=compat.evaluateins_to_recorddict(evalins, True),
             dst_node_id=proxy.node_id,
+            message_type=MessageType.EVALUATE,
             group_id=str(current_round),
         )
         for proxy, evalins in client_instructions
@@ -349,7 +343,7 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:

     # Send instructions to clients and
     # collect `evaluate` results from all clients participating in this round
-    messages = list(
+    messages = list(grid.send_and_receive(out_messages))
     del out_messages
     num_failures = len([msg for msg in messages if msg.has_error()])

@@ -367,7 +361,7 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
     for msg in messages:
         if msg.has_content():
             proxy = node_id_to_proxy[msg.metadata.src_node_id]
-            evalres = compat.
+            evalres = compat.recorddict_to_evaluateres(msg.content)
             if evalres.status.code == Code.OK:
                 results.append((proxy, evalres))
             else: