flwr 1.16.0__py3-none-any.whl → 1.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/run.py +5 -9
- flwr/client/app.py +6 -4
- flwr/client/client_app.py +162 -99
- flwr/client/clientapp/app.py +2 -2
- flwr/client/grpc_client/connection.py +24 -21
- flwr/client/message_handler/message_handler.py +27 -27
- flwr/client/mod/__init__.py +2 -2
- flwr/client/mod/centraldp_mods.py +7 -7
- flwr/client/mod/comms_mods.py +16 -22
- flwr/client/mod/localdp_mod.py +4 -4
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +31 -31
- flwr/client/run_info_store.py +2 -2
- flwr/common/__init__.py +12 -4
- flwr/common/config.py +4 -4
- flwr/common/constant.py +6 -6
- flwr/common/context.py +4 -4
- flwr/common/event_log_plugin/event_log_plugin.py +3 -3
- flwr/common/logger.py +2 -2
- flwr/common/message.py +327 -102
- flwr/common/record/__init__.py +8 -4
- flwr/common/record/arrayrecord.py +626 -0
- flwr/common/record/{configsrecord.py → configrecord.py} +75 -29
- flwr/common/record/conversion_utils.py +1 -1
- flwr/common/record/{metricsrecord.py → metricrecord.py} +78 -32
- flwr/common/record/recorddict.py +288 -0
- flwr/common/recorddict_compat.py +410 -0
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/serde.py +66 -71
- flwr/common/typing.py +8 -8
- flwr/proto/exec_pb2.py +3 -3
- flwr/proto/exec_pb2.pyi +3 -3
- flwr/proto/message_pb2.py +12 -12
- flwr/proto/message_pb2.pyi +9 -9
- flwr/proto/recorddict_pb2.py +70 -0
- flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +35 -35
- flwr/proto/run_pb2.py +31 -31
- flwr/proto/run_pb2.pyi +3 -3
- flwr/server/__init__.py +3 -1
- flwr/server/app.py +56 -1
- flwr/server/compat/__init__.py +2 -2
- flwr/server/compat/app.py +11 -11
- flwr/server/compat/app_utils.py +16 -16
- flwr/server/compat/{driver_client_proxy.py → grid_client_proxy.py} +39 -39
- flwr/server/fleet_event_log_interceptor.py +94 -0
- flwr/server/{driver → grid}/__init__.py +8 -7
- flwr/server/{driver/driver.py → grid/grid.py} +47 -18
- flwr/server/{driver/grpc_driver.py → grid/grpc_grid.py} +87 -64
- flwr/server/{driver/inmemory_driver.py → grid/inmemory_grid.py} +24 -34
- flwr/server/run_serverapp.py +4 -4
- flwr/server/server_app.py +38 -18
- flwr/server/serverapp/app.py +10 -10
- flwr/server/superlink/fleet/vce/backend/backend.py +2 -2
- flwr/server/superlink/fleet/vce/backend/raybackend.py +2 -2
- flwr/server/superlink/fleet/vce/vce_api.py +1 -3
- flwr/server/superlink/linkstate/in_memory_linkstate.py +33 -8
- flwr/server/superlink/linkstate/linkstate.py +4 -4
- flwr/server/superlink/linkstate/sqlite_linkstate.py +61 -27
- flwr/server/superlink/linkstate/utils.py +93 -27
- flwr/server/superlink/{driver → serverappio}/__init__.py +1 -1
- flwr/server/superlink/{driver → serverappio}/serverappio_grpc.py +1 -1
- flwr/server/superlink/{driver → serverappio}/serverappio_servicer.py +4 -4
- flwr/server/superlink/simulation/simulationio_servicer.py +2 -2
- flwr/server/typing.py +3 -3
- flwr/server/utils/validator.py +4 -4
- flwr/server/workflow/default_workflows.py +48 -57
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +48 -50
- flwr/simulation/app.py +2 -2
- flwr/simulation/ray_transport/ray_actor.py +4 -2
- flwr/simulation/ray_transport/ray_client_proxy.py +34 -32
- flwr/simulation/run_simulation.py +15 -15
- flwr/superexec/deployment.py +4 -4
- flwr/superexec/exec_event_log_interceptor.py +135 -0
- flwr/superexec/exec_grpc.py +10 -4
- flwr/superexec/exec_servicer.py +2 -2
- flwr/superexec/exec_user_auth_interceptor.py +18 -2
- flwr/superexec/executor.py +3 -3
- flwr/superexec/simulation.py +3 -3
- {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/METADATA +2 -2
- {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/RECORD +94 -92
- flwr/common/record/parametersrecord.py +0 -339
- flwr/common/record/recordset.py +0 -209
- flwr/common/recordset_compat.py +0 -418
- flwr/proto/recordset_pb2.py +0 -70
- /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
- /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
- {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/LICENSE +0 -0
- {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/WHEEL +0 -0
- {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/entry_points.txt +0 -0
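Before the per-file hunks: the bulk of this release is a coordinated rename of the record API (RecordSet → RecordDict, ParametersRecord → ArrayRecord, ConfigsRecord → ConfigRecord, MetricsRecord → MetricRecord) plus the server-side Driver → Grid move, as the renamed files above indicate. A minimal migration sketch, assuming the 1.17 names shown in the hunks below; `metric_records` is inferred by analogy with `config_records`/`array_records`, and all values are illustrative:

```python
from flwr.common import ArrayRecord, ConfigRecord, MetricRecord, RecordDict

# 1.16 spelling (removed in this diff):
#   from flwr.common import ConfigsRecord, MetricsRecord, ParametersRecord, RecordSet

# Group typed records under string keys in a single RecordDict
content = RecordDict(
    {
        "train-config": ConfigRecord({"lr": 0.01, "epochs": 5}),
        "train-metrics": MetricRecord({"loss": 0.42}),
    }
)

# Typed views, renamed from configs_records/metrics_records/parameters_records;
# config_records and array_records appear in the hunks below, metric_records
# is an assumption by analogy
assert "train-config" in content.config_records
assert "train-metrics" in content.metric_records
```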
flwr/server/workflow/default_workflows.py
CHANGED

@@ -20,15 +20,16 @@ import timeit
 from logging import INFO, WARN
 from typing import Optional, Union, cast

-import flwr.common.recordset_compat as compat
+import flwr.common.recorddict_compat as compat
 from flwr.common import (
+    ArrayRecord,
     Code,
-    ConfigsRecord,
+    ConfigRecord,
     Context,
     EvaluateRes,
     FitRes,
     GetParametersIns,
-    ParametersRecord,
+    Message,
     log,
 )
 from flwr.common.constant import MessageType, MessageTypeLegacy
@@ -36,7 +37,7 @@ from flwr.common.constant import MessageType, MessageTypeLegacy
 from ..client_proxy import ClientProxy
 from ..compat.app_utils import start_update_client_manager_thread
 from ..compat.legacy_context import LegacyContext
-from ..driver import Driver
+from ..grid import Grid
 from ..typing import Workflow
 from .constant import MAIN_CONFIGS_RECORD, MAIN_PARAMS_RECORD, Key

@@ -56,7 +57,7 @@ class DefaultWorkflow:
         self.fit_workflow: Workflow = fit_workflow
         self.evaluate_workflow: Workflow = evaluate_workflow

-    def __call__(self, driver: Driver, context: Context) -> None:
+    def __call__(self, grid: Grid, context: Context) -> None:
         """Execute the workflow."""
         if not isinstance(context, LegacyContext):
             raise TypeError(
@@ -65,7 +66,7 @@

         # Start the thread updating nodes
         thread, f_stop, c_done = start_update_client_manager_thread(
-            driver, context.client_manager
+            grid, context.client_manager
         )

         # Wait until the node registration done
@@ -73,13 +74,13 @@

         # Initialize parameters
         log(INFO, "[INIT]")
-        default_init_params_workflow(driver, context)
+        default_init_params_workflow(grid, context)

         # Run federated learning for num_rounds
         start_time = timeit.default_timer()
-        cfg = ConfigsRecord()
+        cfg = ConfigRecord()
         cfg[Key.START_TIME] = start_time
-        context.state.configs_records[MAIN_CONFIGS_RECORD] = cfg
+        context.state.config_records[MAIN_CONFIGS_RECORD] = cfg

         for current_round in range(1, context.config.num_rounds + 1):
             log(INFO, "")
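The hunk above also renames the state accessors on `Context`: `configs_records` becomes `config_records` (and, further below, `parameters_records` becomes `array_records`). A small sketch of the store/read pattern used here, with a bare RecordDict standing in for `context.state`:

```python
from flwr.common import ConfigRecord, RecordDict

state = RecordDict()  # stand-in for context.state

cfg = ConfigRecord()
cfg["start_time"] = 12.34            # scalar values, as with Key.START_TIME above
state.config_records["main"] = cfg   # 1.16 spelling: state.configs_records[...]

assert state.config_records["main"]["start_time"] == 12.34
```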
@@ -87,13 +88,13 @@
             cfg[Key.CURRENT_ROUND] = current_round

             # Fit round
-            self.fit_workflow(driver, context)
+            self.fit_workflow(grid, context)

             # Centralized evaluation
-            default_centralized_evaluation_workflow(driver, context)
+            default_centralized_evaluation_workflow(grid, context)

             # Evaluate round
-            self.evaluate_workflow(driver, context)
+            self.evaluate_workflow(grid, context)

             # Bookkeeping and log results
             end_time = timeit.default_timer()
@@ -119,7 +120,7 @@
         thread.join()


-def default_init_params_workflow(driver: Driver, context: Context) -> None:
+def default_init_params_workflow(grid: Grid, context: Context) -> None:
     """Execute the default workflow for parameters initialization."""
     if not isinstance(context, LegacyContext):
         raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
@@ -129,21 +130,19 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None
     )
     if parameters is not None:
         log(INFO, "Using initial global parameters provided by strategy")
-        paramsrecord = compat.parameters_to_parametersrecord(
-            parameters, keep_input=True
-        )
+        arr_record = compat.parameters_to_arrayrecord(parameters, keep_input=True)
     else:
         # Get initial parameters from one of the clients
         log(INFO, "Requesting initial parameters from one random client")
         random_client = context.client_manager.sample(1)[0]
         # Send GetParametersIns and get the response
-        content = compat.getparametersins_to_recordset(GetParametersIns({}))
-        messages = driver.send_and_receive(
+        content = compat.getparametersins_to_recorddict(GetParametersIns({}))
+        messages = grid.send_and_receive(
             [
-                driver.create_message(
+                Message(
                     content=content,
-                    message_type=MessageTypeLegacy.GET_PARAMETERS,
                     dst_node_id=random_client.node_id,
+                    message_type=MessageTypeLegacy.GET_PARAMETERS,
                     group_id="0",
                 )
             ]
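Note the constructor change running through these hunks: 1.16 built messages via `driver.create_message(...)`, while 1.17 instantiates `Message` directly, with `content` first and `message_type` after `dst_node_id`. A hedged sketch of the new construction; the node and group ids are placeholders:

```python
from flwr.common import ConfigRecord, Message, MessageType, RecordDict

content = RecordDict({"config": ConfigRecord({"round": 1})})
msg = Message(
    content=content,
    dst_node_id=123,                 # placeholder node id
    message_type=MessageType.TRAIN,  # as in the TRAIN/EVALUATE hunks here
    group_id="1",
)
assert msg.has_content()
```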
@@ -152,26 +151,26 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:

    if (
        msg.has_content()
-        and compat._extract_status_from_recordset(  # pylint: disable=W0212
+        and compat._extract_status_from_recorddict(  # pylint: disable=W0212
            "getparametersres", msg.content
        ).code
        == Code.OK
    ):
        log(INFO, "Received initial parameters from one random client")
-        paramsrecord = next(iter(msg.content.parameters_records.values()))
+        arr_record = next(iter(msg.content.array_records.values()))
    else:
        log(
            WARN,
            "Failed to receive initial parameters from the client."
            " Empty initial parameters will be used.",
        )
-        paramsrecord = ParametersRecord()
+        arr_record = ArrayRecord()

-    context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
+    context.state.array_records[MAIN_PARAMS_RECORD] = arr_record

    # Evaluate initial parameters
    log(INFO, "Starting evaluation of initial global parameters")
-    parameters = compat.parametersrecord_to_parameters(paramsrecord, keep_input=True)
+    parameters = compat.arrayrecord_to_parameters(arr_record, keep_input=True)
    res = context.strategy.evaluate(0, parameters=parameters)
    if res is not None:
        log(
@@ -186,19 +185,19 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
        log(INFO, "Evaluation returned no results (`None`)")


-def default_centralized_evaluation_workflow(_: Driver, context: Context) -> None:
+def default_centralized_evaluation_workflow(_: Grid, context: Context) -> None:
    """Execute the default workflow for centralized evaluation."""
    if not isinstance(context, LegacyContext):
        raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")

    # Retrieve current_round and start_time from the context
-    cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
+    cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
    current_round = cast(int, cfg[Key.CURRENT_ROUND])
    start_time = cast(float, cfg[Key.START_TIME])

    # Centralized evaluation
-    parameters = compat.parametersrecord_to_parameters(
-        record=context.state.parameters_records[MAIN_PARAMS_RECORD],
+    parameters = compat.arrayrecord_to_parameters(
+        record=context.state.array_records[MAIN_PARAMS_RECORD],
        keep_input=True,
    )
    res_cen = context.strategy.evaluate(current_round, parameters=parameters)
@@ -218,20 +217,16 @@ def default_centralized_evaluation_workflow(_: Driver, context: Context) -> None
        )


-def default_fit_workflow(  # pylint: disable=R0914
-    driver: Driver, context: Context
-) -> None:
+def default_fit_workflow(grid: Grid, context: Context) -> None:  # pylint: disable=R0914
    """Execute the default workflow for a single fit round."""
    if not isinstance(context, LegacyContext):
        raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")

    # Get current_round and parameters
-    cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
+    cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
    current_round = cast(int, cfg[Key.CURRENT_ROUND])
-    parametersrecord = context.state.parameters_records[MAIN_PARAMS_RECORD]
-    parameters = compat.parametersrecord_to_parameters(
-        parametersrecord, keep_input=True
-    )
+    arr_record = context.state.array_records[MAIN_PARAMS_RECORD]
+    parameters = compat.arrayrecord_to_parameters(arr_record, keep_input=True)

    # Get clients and their respective instructions from strategy
    client_instructions = context.strategy.configure_fit(
@@ -255,10 +250,10 @@ def default_fit_workflow(  # pylint: disable=R0914

    # Build out messages
    out_messages = [
-        driver.create_message(
-            content=compat.fitins_to_recordset(fitins, True),
-            message_type=MessageType.TRAIN,
+        Message(
+            content=compat.fitins_to_recorddict(fitins, True),
            dst_node_id=proxy.node_id,
+            message_type=MessageType.TRAIN,
            group_id=str(current_round),
        )
        for proxy, fitins in client_instructions
@@ -266,7 +261,7 @@ def default_fit_workflow(  # pylint: disable=R0914

    # Send instructions to clients and
    # collect `fit` results from all clients participating in this round
-    messages = list(driver.send_and_receive(out_messages))
+    messages = list(grid.send_and_receive(out_messages))
    del out_messages
    num_failures = len([msg for msg in messages if msg.has_error()])

@@ -284,7 +279,7 @@ def default_fit_workflow(  # pylint: disable=R0914
    for msg in messages:
        if msg.has_content():
            proxy = node_id_to_proxy[msg.metadata.src_node_id]
-            fitres = compat.recordset_to_fitres(msg.content, False)
+            fitres = compat.recorddict_to_fitres(msg.content, False)
            if fitres.status.code == Code.OK:
                results.append((proxy, fitres))
            else:
@@ -297,28 +292,24 @@ def default_fit_workflow(  # pylint: disable=R0914

    # Update the parameters and write history
    if parameters_aggregated:
-        paramsrecord = compat.parameters_to_parametersrecord(
-            parameters_aggregated, True
-        )
-        context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
+        arr_record = compat.parameters_to_arrayrecord(parameters_aggregated, True)
+        context.state.array_records[MAIN_PARAMS_RECORD] = arr_record
        context.history.add_metrics_distributed_fit(
            server_round=current_round, metrics=metrics_aggregated
        )


 # pylint: disable-next=R0914
-def default_evaluate_workflow(driver: Driver, context: Context) -> None:
+def default_evaluate_workflow(grid: Grid, context: Context) -> None:
    """Execute the default workflow for a single evaluate round."""
    if not isinstance(context, LegacyContext):
        raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")

    # Get current_round and parameters
-    cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
+    cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
    current_round = cast(int, cfg[Key.CURRENT_ROUND])
-    parametersrecord = context.state.parameters_records[MAIN_PARAMS_RECORD]
-    parameters = compat.parametersrecord_to_parameters(
-        parametersrecord, keep_input=True
-    )
+    arr_record = context.state.array_records[MAIN_PARAMS_RECORD]
+    parameters = compat.arrayrecord_to_parameters(arr_record, keep_input=True)

    # Get clients and their respective instructions from strategy
    client_instructions = context.strategy.configure_evaluate(
@@ -341,10 +332,10 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:

    # Build out messages
    out_messages = [
-        driver.create_message(
-            content=compat.evaluateins_to_recordset(evalins, True),
-            message_type=MessageType.EVALUATE,
+        Message(
+            content=compat.evaluateins_to_recorddict(evalins, True),
            dst_node_id=proxy.node_id,
+            message_type=MessageType.EVALUATE,
            group_id=str(current_round),
        )
        for proxy, evalins in client_instructions
@@ -352,7 +343,7 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:

    # Send instructions to clients and
    # collect `evaluate` results from all clients participating in this round
-    messages = list(driver.send_and_receive(out_messages))
+    messages = list(grid.send_and_receive(out_messages))
    del out_messages
    num_failures = len([msg for msg in messages if msg.has_error()])

@@ -370,7 +361,7 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
    for msg in messages:
        if msg.has_content():
            proxy = node_id_to_proxy[msg.metadata.src_node_id]
-            evalres = compat.recordset_to_evaluateres(msg.content)
+            evalres = compat.recorddict_to_evaluateres(msg.content)
            if evalres.status.code == Code.OK:
                results.append((proxy, evalres))
            else:
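With default_workflows.py fully migrated, every workflow callable now takes `(grid: Grid, context: Context)` instead of `(driver: Driver, context: Context)`. A minimal custom-workflow sketch under the new signature; it assumes `Grid` is importable from `flwr.server` (its `__init__.py` gains two lines above) and that `get_node_ids()` carried over from `Driver`:

```python
from flwr.common import Context, Message, MessageType, RecordDict
from flwr.server import Grid  # assumed re-export; flwr/server/__init__.py +3 -1


def my_workflow(grid: Grid, context: Context) -> None:
    """Toy workflow: broadcast an empty payload and log who replied."""
    out_messages = [
        Message(
            content=RecordDict(),
            dst_node_id=node_id,
            message_type=MessageType.TRAIN,
            group_id="0",
        )
        for node_id in grid.get_node_ids()  # assumed, mirrors Driver.get_node_ids
    ]
    for reply in grid.send_and_receive(out_messages):
        if reply.has_content():
            print("reply from node", reply.metadata.src_node_id)
```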
flwr/server/workflow/secure_aggregation/secaggplus_workflow.py
CHANGED

@@ -20,15 +20,15 @@ from dataclasses import dataclass, field
 from logging import DEBUG, ERROR, INFO, WARN
 from typing import Optional, Union, cast

-import flwr.common.recordset_compat as compat
+import flwr.common.recorddict_compat as compat
 from flwr.common import (
-    ConfigsRecord,
+    ConfigRecord,
     Context,
     FitRes,
     Message,
     MessageType,
     NDArrays,
-    RecordSet,
+    RecordDict,
     bytes_to_ndarray,
     log,
     ndarrays_to_parameters,
@@ -55,7 +55,7 @@ from flwr.common.secure_aggregation.secaggplus_constants import (
 from flwr.common.secure_aggregation.secaggplus_utils import pseudo_rand_gen
 from flwr.server.client_proxy import ClientProxy
 from flwr.server.compat.legacy_context import LegacyContext
-from flwr.server.driver import Driver
+from flwr.server.grid import Grid

 from ..constant import MAIN_CONFIGS_RECORD, MAIN_PARAMS_RECORD
 from ..constant import Key as WorkflowKey
@@ -66,7 +66,7 @@ class WorkflowState:  # pylint: disable=R0902
     """The state of the SecAgg+ protocol."""

     nid_to_proxies: dict[int, ClientProxy] = field(default_factory=dict)
-    nid_to_fitins: dict[int, RecordSet] = field(default_factory=dict)
+    nid_to_fitins: dict[int, RecordDict] = field(default_factory=dict)
     sampled_node_ids: set[int] = field(default_factory=set)
     active_node_ids: set[int] = field(default_factory=set)
     num_shares: int = 0
@@ -186,7 +186,7 @@ class SecAggPlusWorkflow:

         self._check_init_params()

-    def __call__(self, driver: Driver, context: Context) -> None:
+    def __call__(self, grid: Grid, context: Context) -> None:
         """Run the SecAgg+ protocol."""
         if not isinstance(context, LegacyContext):
             raise TypeError(
@@ -202,7 +202,7 @@ class SecAggPlusWorkflow:
         )
         log(INFO, "Secure aggregation commencing.")
         for step in steps:
-            if not step(driver, context, state):
+            if not step(grid, context, state):
                 log(INFO, "Secure aggregation halted.")
                 return
         log(INFO, "Secure aggregation completed.")
@@ -279,14 +279,14 @@
         return True

     def setup_stage(  # pylint: disable=R0912, R0914, R0915
-        self, driver: Driver, context: LegacyContext, state: WorkflowState
+        self, grid: Grid, context: LegacyContext, state: WorkflowState
     ) -> bool:
         """Execute the 'setup' stage."""
         # Obtain fit instructions
-        cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
+        cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
         current_round = cast(int, cfg[WorkflowKey.CURRENT_ROUND])
-        parameters = compat.parametersrecord_to_parameters(
-            context.state.parameters_records[MAIN_PARAMS_RECORD],
+        parameters = compat.arrayrecord_to_parameters(
+            context.state.array_records[MAIN_PARAMS_RECORD],
             keep_input=True,
         )
         proxy_fitins_lst = context.strategy.configure_fit(
@@ -303,7 +303,7 @@
         )

         state.nid_to_fitins = {
-            proxy.node_id: compat.fitins_to_recordset(fitins, True)
+            proxy.node_id: compat.fitins_to_recorddict(fitins, True)
             for proxy, fitins in proxy_fitins_lst
         }
         state.nid_to_proxies = {proxy.node_id: proxy for proxy, _ in proxy_fitins_lst}
@@ -366,14 +366,14 @@
         state.sampled_node_ids = state.active_node_ids

         # Send setup configuration to clients
-        cfg_record = ConfigsRecord(sa_params_dict)  # type: ignore
-        content = RecordSet(configs_records={RECORD_KEY_CONFIGS: cfg_record})
+        cfg_record = ConfigRecord(sa_params_dict)  # type: ignore
+        content = RecordDict({RECORD_KEY_CONFIGS: cfg_record})

         def make(nid: int) -> Message:
-            return driver.create_message(
+            return Message(
                 content=content,
-                message_type=MessageType.TRAIN,
                 dst_node_id=nid,
+                message_type=MessageType.TRAIN,
                 group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
             )

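The setup stage above packs its protocol fields into a `ConfigRecord` stored under `RECORD_KEY_CONFIGS` and later reads them back through `cast(...)`. A self-contained round-trip sketch of that pattern; the key names here are illustrative, not the real SecAgg+ constants:

```python
from typing import cast

from flwr.common import ConfigRecord, RecordDict

# Pack: ConfigRecord accepts scalars and lists of scalars/bytes
cfg_record = ConfigRecord({"stage": "setup", "ciphertexts": [b"c1", b"c2"]})
content = RecordDict({"secaggplus": cfg_record})

# Read back: values are loosely typed, hence the cast() calls in the workflow
res_dict = content.config_records["secaggplus"]
ciphertexts = cast(list[bytes], res_dict["ciphertexts"])
assert ciphertexts == [b"c1", b"c2"]
```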
@@ -382,7 +382,7 @@
             "[Stage 0] Sending configurations to %s clients.",
             len(state.active_node_ids),
         )
-        msgs = driver.send_and_receive(
+        msgs = grid.send_and_receive(
             [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
         )
         state.active_node_ids = {
@@ -398,7 +398,7 @@
             if msg.has_error():
                 state.failures.append(Exception(msg.error))
                 continue
-            key_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
+            key_dict = msg.content.config_records[RECORD_KEY_CONFIGS]
             node_id = msg.metadata.src_node_id
             pk1, pk2 = key_dict[Key.PUBLIC_KEY_1], key_dict[Key.PUBLIC_KEY_2]
             state.nid_to_publickeys[node_id] = [cast(bytes, pk1), cast(bytes, pk2)]
@@ -406,22 +406,22 @@
         return self._check_threshold(state)

     def share_keys_stage(  # pylint: disable=R0914
-        self, driver: Driver, context: LegacyContext, state: WorkflowState
+        self, grid: Grid, context: LegacyContext, state: WorkflowState
     ) -> bool:
         """Execute the 'share keys' stage."""
-        cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
+        cfg = context.state.config_records[MAIN_CONFIGS_RECORD]

         def make(nid: int) -> Message:
             neighbours = state.nid_to_neighbours[nid] & state.active_node_ids
-            cfg_record = ConfigsRecord(
+            cfg_record = ConfigRecord(
                 {str(nid): state.nid_to_publickeys[nid] for nid in neighbours}
             )
-            cfg_record[Key.STAGE] = Stage.SHARE_KEYS
-            content = RecordSet(configs_records={RECORD_KEY_CONFIGS: cfg_record})
-            return driver.create_message(
+            cfg_record[Key.STAGE] = Stage.SHARE_KEYS
+            content = RecordDict({RECORD_KEY_CONFIGS: cfg_record})
+            return Message(
                 content=content,
-                message_type=MessageType.TRAIN,
                 dst_node_id=nid,
+                message_type=MessageType.TRAIN,
                 group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
             )

@@ -431,7 +431,7 @@
             "[Stage 1] Forwarding public keys to %s clients.",
             len(state.active_node_ids),
         )
-        msgs = driver.send_and_receive(
+        msgs = grid.send_and_receive(
             [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
         )
         state.active_node_ids = {
@@ -458,7 +458,7 @@
                 state.failures.append(Exception(msg.error))
                 continue
             node_id = msg.metadata.src_node_id
-            res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
+            res_dict = msg.content.config_records[RECORD_KEY_CONFIGS]
             dst_lst = cast(list[int], res_dict[Key.DESTINATION_LIST])
             ctxt_lst = cast(list[bytes], res_dict[Key.CIPHERTEXT_LIST])
             srcs += [node_id] * len(dst_lst)
@@ -476,25 +476,25 @@
         return self._check_threshold(state)

     def collect_masked_vectors_stage(
-        self, driver: Driver, context: LegacyContext, state: WorkflowState
+        self, grid: Grid, context: LegacyContext, state: WorkflowState
     ) -> bool:
         """Execute the 'collect masked vectors' stage."""
-        cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
+        cfg = context.state.config_records[MAIN_CONFIGS_RECORD]

         # Send secret key shares to clients (plus FitIns) and collect masked vectors
         def make(nid: int) -> Message:
-            cfg_dict = {
+            cfg_dict = {
                 Key.STAGE: Stage.COLLECT_MASKED_VECTORS,
                 Key.CIPHERTEXT_LIST: state.forward_ciphertexts[nid],
                 Key.SOURCE_LIST: state.forward_srcs[nid],
             }
-            cfg_record = ConfigsRecord(cfg_dict)  # type: ignore
+            cfg_record = ConfigRecord(cfg_dict)  # type: ignore
             content = state.nid_to_fitins[nid]
-            content.configs_records[RECORD_KEY_CONFIGS] = cfg_record
-            return driver.create_message(
+            content.config_records[RECORD_KEY_CONFIGS] = cfg_record
+            return Message(
                 content=content,
-                message_type=MessageType.TRAIN,
                 dst_node_id=nid,
+                message_type=MessageType.TRAIN,
                 group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
             )

@@ -503,7 +503,7 @@
             "[Stage 2] Forwarding encrypted key shares to %s clients.",
             len(state.active_node_ids),
         )
-        msgs = driver.send_and_receive(
+        msgs = grid.send_and_receive(
             [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
         )
         state.active_node_ids = {
@@ -524,7 +524,7 @@
             if msg.has_error():
                 state.failures.append(Exception(msg.error))
                 continue
-            res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
+            res_dict = msg.content.config_records[RECORD_KEY_CONFIGS]
             bytes_list = cast(list[bytes], res_dict[Key.MASKED_PARAMETERS])
             client_masked_vec = [bytes_to_ndarray(b) for b in bytes_list]
             if masked_vector is None:
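The masked-vector collection above turns each reply's byte payloads back into arrays with `bytes_to_ndarray`. A quick round-trip sketch with its counterpart `ndarray_to_bytes` (both live in flwr.common and are untouched by this diff):

```python
import numpy as np

from flwr.common import bytes_to_ndarray, ndarray_to_bytes

vec = np.arange(6, dtype=np.float32).reshape(2, 3)
payload = ndarray_to_bytes(vec)      # serialized NumPy buffer
restored = bytes_to_ndarray(payload)
assert np.array_equal(vec, restored)
```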
@@ -540,17 +540,17 @@
             if msg.has_error():
                 state.failures.append(Exception(msg.error))
                 continue
-            fitres = compat.recordset_to_fitres(msg.content, True)
+            fitres = compat.recorddict_to_fitres(msg.content, True)
             proxy = state.nid_to_proxies[msg.metadata.src_node_id]
             state.legacy_results.append((proxy, fitres))

         return self._check_threshold(state)

     def unmask_stage(  # pylint: disable=R0912, R0914, R0915
-        self, driver: Driver, context: LegacyContext, state: WorkflowState
+        self, grid: Grid, context: LegacyContext, state: WorkflowState
     ) -> bool:
         """Execute the 'unmask' stage."""
-        cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
+        cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
         current_round = cast(int, cfg[WorkflowKey.CURRENT_ROUND])

         # Construct active node IDs and dead node IDs
@@ -560,17 +560,17 @@
         # Send secure IDs of active and dead clients and collect key shares from clients
         def make(nid: int) -> Message:
             neighbours = state.nid_to_neighbours[nid]
-            cfg_dict = {
+            cfg_dict = {
                 Key.STAGE: Stage.UNMASK,
                 Key.ACTIVE_NODE_ID_LIST: list(neighbours & active_nids),
                 Key.DEAD_NODE_ID_LIST: list(neighbours & dead_nids),
             }
-            cfg_record = ConfigsRecord(cfg_dict)  # type: ignore
-            content = RecordSet(configs_records={RECORD_KEY_CONFIGS: cfg_record})
-            return driver.create_message(
+            cfg_record = ConfigRecord(cfg_dict)  # type: ignore
+            content = RecordDict({RECORD_KEY_CONFIGS: cfg_record})
+            return Message(
                 content=content,
-                message_type=MessageType.TRAIN,
                 dst_node_id=nid,
+                message_type=MessageType.TRAIN,
                 group_id=str(current_round),
             )

@@ -579,7 +579,7 @@
             "[Stage 3] Requesting key shares from %s clients to remove masks.",
             len(state.active_node_ids),
         )
-        msgs = driver.send_and_receive(
+        msgs = grid.send_and_receive(
             [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
         )
         state.active_node_ids = {
@@ -599,7 +599,7 @@
             if msg.has_error():
                 state.failures.append(Exception(msg.error))
                 continue
-            res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
+            res_dict = msg.content.config_records[RECORD_KEY_CONFIGS]
             nids = cast(list[int], res_dict[Key.NODE_ID_LIST])
             shares = cast(list[bytes], res_dict[Key.SHARE_LIST])
             for owner_nid, share in zip(nids, shares):
@@ -676,10 +676,8 @@

         # Update the parameters and write history
         if parameters_aggregated:
-            paramsrecord = compat.parameters_to_parametersrecord(
-                parameters_aggregated, True
-            )
-            context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
+            arr_record = compat.parameters_to_arrayrecord(parameters_aggregated, True)
+            context.state.array_records[MAIN_PARAMS_RECORD] = arr_record
             context.history.add_metrics_distributed_fit(
                 server_round=current_round, metrics=metrics_aggregated
             )
flwr/simulation/app.py
CHANGED

@@ -47,7 +47,7 @@ from flwr.common.logger import (
     stop_log_uploader,
 )
 from flwr.common.serde import (
-    configs_record_from_proto,
+    config_record_from_proto,
     context_from_proto,
     context_to_proto,
     fab_from_proto,
@@ -184,7 +184,7 @@ def run_simulation_process(  # pylint: disable=R0914, disable=W0212, disable=R09
     fed_opt_res: GetFederationOptionsResponse = conn._stub.GetFederationOptions(
         GetFederationOptionsRequest(run_id=run.run_id)
     )
-    federation_options = configs_record_from_proto(
+    federation_options = config_record_from_proto(
         fed_opt_res.federation_options
     )

flwr/simulation/ray_transport/ray_actor.py
CHANGED

@@ -105,8 +105,10 @@ def pool_size_from_resources(client_resources: dict[str, Union[int, float]]) ->
         if not node_resources:
             continue

-        num_cpus = node_resources["CPU"]
-        num_gpus = node_resources.get("GPU", 0)
+        # Fallback to zero when resource quantity is not configured on the ray node
+        # e.g.: node without GPU; head node set up not to run tasks (zero resources)
+        num_cpus = node_resources.get("CPU", 0)
+        num_gpus = node_resources.get("GPU", 0)
         num_actors = int(num_cpus / client_resources["num_cpus"])

         # If a GPU is present and client resources do require one