flwr-nightly 1.8.0.dev20240309__py3-none-any.whl → 1.8.0.dev20240311__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/flower_toml.py +4 -48
- flwr/cli/new/new.py +6 -3
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -3
- flwr/cli/new/templates/app/pyproject.toml.tpl +1 -1
- flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +2 -2
- flwr/cli/utils.py +14 -1
- flwr/client/app.py +39 -5
- flwr/client/client_app.py +1 -47
- flwr/client/mod/__init__.py +2 -1
- flwr/client/mod/secure_aggregation/__init__.py +2 -0
- flwr/client/mod/secure_aggregation/secagg_mod.py +30 -0
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +73 -57
- flwr/common/grpc.py +3 -3
- flwr/common/logger.py +78 -15
- flwr/common/object_ref.py +140 -0
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +5 -5
- flwr/common/secure_aggregation/secaggplus_constants.py +7 -6
- flwr/common/secure_aggregation/secaggplus_utils.py +15 -15
- flwr/server/compat/app.py +2 -1
- flwr/server/driver/grpc_driver.py +4 -4
- flwr/server/history.py +22 -15
- flwr/server/run_serverapp.py +22 -4
- flwr/server/server.py +27 -23
- flwr/server/server_app.py +1 -47
- flwr/server/server_config.py +9 -0
- flwr/server/strategy/fedavg.py +2 -0
- flwr/server/superlink/fleet/vce/vce_api.py +9 -2
- flwr/server/superlink/state/in_memory_state.py +34 -32
- flwr/server/workflow/__init__.py +3 -0
- flwr/server/workflow/constant.py +32 -0
- flwr/server/workflow/default_workflows.py +52 -57
- flwr/server/workflow/secure_aggregation/__init__.py +24 -0
- flwr/server/workflow/secure_aggregation/secagg_workflow.py +112 -0
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +676 -0
- {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/METADATA +1 -1
- {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/RECORD +39 -33
- {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/entry_points.txt +0 -0
flwr/server/workflow/default_workflows.py

```diff
@@ -15,8 +15,9 @@
 """Legacy default workflows."""
 
 
+import io
 import timeit
-from logging import
+from logging import INFO
 from typing import Optional, cast
 
 import flwr.common.recordset_compat as compat
@@ -27,11 +28,7 @@ from ..compat.app_utils import start_update_client_manager_thread
 from ..compat.legacy_context import LegacyContext
 from ..driver import Driver
 from ..typing import Workflow
-
-KEY_CURRENT_ROUND = "current_round"
-KEY_START_TIME = "start_time"
-CONFIGS_RECORD_KEY = "config"
-PARAMS_RECORD_KEY = "parameters"
+from .constant import MAIN_CONFIGS_RECORD, MAIN_PARAMS_RECORD, Key
 
 
 class DefaultWorkflow:
@@ -62,17 +59,19 @@ class DefaultWorkflow:
         )
 
         # Initialize parameters
+        log(INFO, "[INIT]")
         default_init_params_workflow(driver, context)
 
         # Run federated learning for num_rounds
-        log(INFO, "FL starting")
         start_time = timeit.default_timer()
         cfg = ConfigsRecord()
-        cfg[
-        context.state.configs_records[
+        cfg[Key.START_TIME] = start_time
+        context.state.configs_records[MAIN_CONFIGS_RECORD] = cfg
 
         for current_round in range(1, context.config.num_rounds + 1):
-
+            log(INFO, "")
+            log(INFO, "[ROUND %s]", current_round)
+            cfg[Key.CURRENT_ROUND] = current_round
 
             # Fit round
             self.fit_workflow(driver, context)
@@ -83,22 +82,19 @@ class DefaultWorkflow:
             # Evaluate round
             self.evaluate_workflow(driver, context)
 
-        # Bookkeeping
+        # Bookkeeping and log results
         end_time = timeit.default_timer()
         elapsed = end_time - start_time
-        log(INFO, "FL finished in %s", elapsed)
-
-        # Log results
         hist = context.history
-        log(INFO, "
-        log(
-
-
-
-
-
-
-        log(INFO, "
+        log(INFO, "")
+        log(INFO, "[SUMMARY]")
+        log(INFO, "Run finished %s rounds in %.2fs", context.config.num_rounds, elapsed)
+        for idx, line in enumerate(io.StringIO(str(hist))):
+            if idx == 0:
+                log(INFO, "%s", line.strip("\n"))
+            else:
+                log(INFO, "\t%s", line.strip("\n"))
+        log(INFO, "")
 
         # Terminate the thread
         f_stop.set()
@@ -111,12 +107,11 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
     if not isinstance(context, LegacyContext):
         raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
 
-    log(INFO, "Initializing global parameters")
     parameters = context.strategy.initialize_parameters(
         client_manager=context.client_manager
     )
     if parameters is not None:
-        log(INFO, "Using initial parameters provided by strategy")
+        log(INFO, "Using initial global parameters provided by strategy")
         paramsrecord = compat.parameters_to_parametersrecord(
             parameters, keep_input=True
         )
@@ -141,10 +136,10 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
         msg = list(messages)[0]
         paramsrecord = next(iter(msg.content.parameters_records.values()))
 
-    context.state.parameters_records[
+    context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
 
     # Evaluate initial parameters
-    log(INFO, "Evaluating initial parameters")
+    log(INFO, "Evaluating initial global parameters")
    parameters = compat.parametersrecord_to_parameters(paramsrecord, keep_input=True)
     res = context.strategy.evaluate(0, parameters=parameters)
     if res is not None:
@@ -164,13 +159,13 @@ def default_centralized_evaluation_workflow(_: Driver, context: Context) -> None
         raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
 
     # Retrieve current_round and start_time from the context
-    cfg = context.state.configs_records[
-    current_round = cast(int, cfg[
-    start_time = cast(float, cfg[
+    cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
+    current_round = cast(int, cfg[Key.CURRENT_ROUND])
+    start_time = cast(float, cfg[Key.START_TIME])
 
     # Centralized evaluation
     parameters = compat.parametersrecord_to_parameters(
-        record=context.state.parameters_records[
+        record=context.state.parameters_records[MAIN_PARAMS_RECORD],
         keep_input=True,
     )
     res_cen = context.strategy.evaluate(current_round, parameters=parameters)
@@ -190,15 +185,17 @@ def default_centralized_evaluation_workflow(_: Driver, context: Context) -> None
     )
 
 
-def default_fit_workflow(
+def default_fit_workflow(  # pylint: disable=R0914
+    driver: Driver, context: Context
+) -> None:
     """Execute the default workflow for a single fit round."""
     if not isinstance(context, LegacyContext):
         raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
 
     # Get current_round and parameters
-    cfg = context.state.configs_records[
-    current_round = cast(int, cfg[
-    parametersrecord = context.state.parameters_records[
+    cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
+    current_round = cast(int, cfg[Key.CURRENT_ROUND])
+    parametersrecord = context.state.parameters_records[MAIN_PARAMS_RECORD]
     parameters = compat.parametersrecord_to_parameters(
         parametersrecord, keep_input=True
     )
@@ -211,12 +208,11 @@ def default_fit_workflow(driver: Driver, context: Context) -> None:
     )
 
     if not client_instructions:
-        log(INFO, "
+        log(INFO, "configure_fit: no clients selected, cancel")
         return
     log(
-
-        "
-        current_round,
+        INFO,
+        "configure_fit: strategy sampled %s clients (out of %s)",
         len(client_instructions),
         context.client_manager.num_available(),
     )
@@ -240,14 +236,14 @@ def default_fit_workflow(driver: Driver, context: Context) -> None:
     # collect `fit` results from all clients participating in this round
     messages = list(driver.send_and_receive(out_messages))
     del out_messages
+    num_failures = len([msg for msg in messages if msg.has_error()])
 
     # No exception/failure handling currently
     log(
-
-        "
-
-
-        0,
+        INFO,
+        "aggregate_fit: received %s results and %s failures",
+        len(messages) - num_failures,
+        num_failures,
     )
 
     # Aggregate training results
@@ -266,7 +262,7 @@ def default_fit_workflow(driver: Driver, context: Context) -> None:
         paramsrecord = compat.parameters_to_parametersrecord(
             parameters_aggregated, True
         )
-        context.state.parameters_records[
+        context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
         context.history.add_metrics_distributed_fit(
             server_round=current_round, metrics=metrics_aggregated
         )
@@ -278,9 +274,9 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
         raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
 
     # Get current_round and parameters
-    cfg = context.state.configs_records[
-    current_round = cast(int, cfg[
-    parametersrecord = context.state.parameters_records[
+    cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
+    current_round = cast(int, cfg[Key.CURRENT_ROUND])
+    parametersrecord = context.state.parameters_records[MAIN_PARAMS_RECORD]
     parameters = compat.parametersrecord_to_parameters(
         parametersrecord, keep_input=True
     )
@@ -292,12 +288,11 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
         client_manager=context.client_manager,
     )
     if not client_instructions:
-        log(INFO, "
+        log(INFO, "configure_evaluate: no clients selected, skipping evaluation")
         return
     log(
-
-        "
-        current_round,
+        INFO,
+        "configure_evaluate: strategy sampled %s clients (out of %s)",
         len(client_instructions),
         context.client_manager.num_available(),
     )
@@ -321,14 +316,14 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
     # collect `evaluate` results from all clients participating in this round
     messages = list(driver.send_and_receive(out_messages))
     del out_messages
+    num_failures = len([msg for msg in messages if msg.has_error()])
 
     # No exception/failure handling currently
     log(
-
-        "
-
-
-        0,
+        INFO,
+        "aggregate_evaluate: received %s results and %s failures",
+        len(messages) - num_failures,
+        num_failures,
     )
 
     # Aggregate the evaluation results
```
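The `Key`, `MAIN_CONFIGS_RECORD`, and `MAIN_PARAMS_RECORD` names imported above come from the new `flwr/server/workflow/constant.py` (+32 lines in the file list, not rendered in this diff). Below is a minimal sketch of what that module plausibly contains, inferred from the removed module-level constants and the `Key.*` accesses shown above; the actual file may define additional keys and a module docstring/license header.

```python
# Hypothetical reconstruction of flwr/server/workflow/constant.py, inferred from
# the removed constants (KEY_CURRENT_ROUND, KEY_START_TIME, CONFIGS_RECORD_KEY,
# PARAMS_RECORD_KEY) and the Key.* usages in default_workflows.py.
MAIN_CONFIGS_RECORD = "config"
MAIN_PARAMS_RECORD = "parameters"


class Key:
    """Keys used in the main ConfigsRecord of the default workflows."""

    CURRENT_ROUND = "current_round"
    START_TIME = "start_time"
```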
flwr/server/workflow/secure_aggregation/__init__.py

```diff
@@ -0,0 +1,24 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Secure Aggregation workflows."""
+
+
+from .secagg_workflow import SecAggWorkflow
+from .secaggplus_workflow import SecAggPlusWorkflow
+
+__all__ = [
+    "SecAggPlusWorkflow",
+    "SecAggWorkflow",
+]
```
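The file list also shows `flwr/server/workflow/__init__.py` gaining three lines, which presumably re-export these classes at the workflow package level. The import below is a hedged sketch of that surface, not confirmed by this diff:

```python
# Assumed re-export; if flwr.server.workflow does not expose these names,
# import them from flwr.server.workflow.secure_aggregation instead.
from flwr.server.workflow import SecAggPlusWorkflow, SecAggWorkflow
```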
flwr/server/workflow/secure_aggregation/secagg_workflow.py

```diff
@@ -0,0 +1,112 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Workflow for the SecAgg protocol."""
+
+
+from typing import Optional, Union
+
+from .secaggplus_workflow import SecAggPlusWorkflow
+
+
+class SecAggWorkflow(SecAggPlusWorkflow):
+    """The workflow for the SecAgg protocol.
+
+    The SecAgg protocol ensures the secure summation of integer vectors owned by
+    multiple parties, without accessing any individual integer vector. This workflow
+    allows the server to compute the weighted average of model parameters across all
+    clients, ensuring individual contributions remain private. This is achieved by
+    clients sending both, a weighting factor and a weighted version of the locally
+    updated parameters, both of which are masked for privacy. Specifically, each
+    client uploads "[w, w * params]" with masks, where weighting factor 'w' is the
+    number of examples ('num_examples') and 'params' represents the model parameters
+    ('parameters') from the client's `FitRes`. The server then aggregates these
+    contributions to compute the weighted average of model parameters.
+
+    The protocol involves four main stages:
+    - 'setup': Send SecAgg configuration to clients and collect their public keys.
+    - 'share keys': Broadcast public keys among clients and collect encrypted secret
+      key shares.
+    - 'collect masked vectors': Forward encrypted secret key shares to target clients
+      and collect masked model parameters.
+    - 'unmask': Collect secret key shares to decrypt and aggregate the model parameters.
+
+    Only the aggregated model parameters are exposed and passed to
+    `Strategy.aggregate_fit`, ensuring individual data privacy.
+
+    Parameters
+    ----------
+    reconstruction_threshold : Union[int, float]
+        The minimum number of shares required to reconstruct a client's private key,
+        or, if specified as a float, it represents the proportion of the total number
+        of shares needed for reconstruction. This threshold ensures privacy by allowing
+        for the recovery of contributions from dropped clients during aggregation,
+        without compromising individual client data.
+    max_weight : Optional[float] (default: 1000.0)
+        The maximum value of the weight that can be assigned to any single client's
+        update during the weighted average calculation on the server side, e.g., in the
+        FedAvg algorithm.
+    clipping_range : float, optional (default: 8.0)
+        The range within which model parameters are clipped before quantization.
+        This parameter ensures each model parameter is bounded within
+        [-clipping_range, clipping_range], facilitating quantization.
+    quantization_range : int, optional (default: 4194304, this equals 2**22)
+        The size of the range into which floating-point model parameters are quantized,
+        mapping each parameter to an integer in [0, quantization_range-1]. This
+        facilitates cryptographic operations on the model updates.
+    modulus_range : int, optional (default: 4294967296, this equals 2**32)
+        The range of values from which random mask entries are uniformly sampled
+        ([0, modulus_range-1]). `modulus_range` must be less than 4294967296.
+        Please use 2**n values for `modulus_range` to prevent overflow issues.
+    timeout : Optional[float] (default: None)
+        The timeout duration in seconds. If specified, the workflow will wait for
+        replies for this duration each time. If `None`, there is no time limit and
+        the workflow will wait until replies for all messages are received.
+
+    Notes
+    -----
+    - Each client's private key is split into N shares under the SecAgg protocol, where
+      N is the number of selected clients.
+    - Generally, higher `reconstruction_threshold` means better privacy guarantees but
+      less tolerance to dropouts.
+    - Too large `max_weight` may compromise the precision of the quantization.
+    - `modulus_range` must be 2**n and larger than `quantization_range`.
+    - When `reconstruction_threshold` is a float, it is interpreted as the proportion of
+      the number of all selected clients needed for the reconstruction of a private key.
+      This feature enables flexibility in setting the security threshold relative to the
+      number of selected clients.
+    - `reconstruction_threshold`, and the quantization parameters
+      (`clipping_range`, `quantization_range`, `modulus_range`) play critical roles in
+      balancing privacy, robustness, and efficiency within the SecAgg protocol.
+    """
+
+    def __init__(  # pylint: disable=R0913
+        self,
+        reconstruction_threshold: Union[int, float],
+        *,
+        max_weight: float = 1000.0,
+        clipping_range: float = 8.0,
+        quantization_range: int = 4194304,
+        modulus_range: int = 4294967296,
+        timeout: Optional[float] = None,
+    ) -> None:
+        super().__init__(
+            num_shares=1.0,
+            reconstruction_threshold=reconstruction_threshold,
+            max_weight=max_weight,
+            clipping_range=clipping_range,
+            quantization_range=quantization_range,
+            modulus_range=modulus_range,
+            timeout=timeout,
+        )
```