flwr-nightly 1.8.0.dev20240309__py3-none-any.whl → 1.8.0.dev20240311__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- flwr/cli/flower_toml.py +4 -48
- flwr/cli/new/new.py +6 -3
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -3
- flwr/cli/new/templates/app/pyproject.toml.tpl +1 -1
- flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +2 -2
- flwr/cli/utils.py +14 -1
- flwr/client/app.py +39 -5
- flwr/client/client_app.py +1 -47
- flwr/client/mod/__init__.py +2 -1
- flwr/client/mod/secure_aggregation/__init__.py +2 -0
- flwr/client/mod/secure_aggregation/secagg_mod.py +30 -0
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +73 -57
- flwr/common/grpc.py +3 -3
- flwr/common/logger.py +78 -15
- flwr/common/object_ref.py +140 -0
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +5 -5
- flwr/common/secure_aggregation/secaggplus_constants.py +7 -6
- flwr/common/secure_aggregation/secaggplus_utils.py +15 -15
- flwr/server/compat/app.py +2 -1
- flwr/server/driver/grpc_driver.py +4 -4
- flwr/server/history.py +22 -15
- flwr/server/run_serverapp.py +22 -4
- flwr/server/server.py +27 -23
- flwr/server/server_app.py +1 -47
- flwr/server/server_config.py +9 -0
- flwr/server/strategy/fedavg.py +2 -0
- flwr/server/superlink/fleet/vce/vce_api.py +9 -2
- flwr/server/superlink/state/in_memory_state.py +34 -32
- flwr/server/workflow/__init__.py +3 -0
- flwr/server/workflow/constant.py +32 -0
- flwr/server/workflow/default_workflows.py +52 -57
- flwr/server/workflow/secure_aggregation/__init__.py +24 -0
- flwr/server/workflow/secure_aggregation/secagg_workflow.py +112 -0
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +676 -0
- {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/METADATA +1 -1
- {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/RECORD +39 -33
- {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/entry_points.txt +0 -0
@@ -15,8 +15,9 @@
|
|
15
15
|
"""Legacy default workflows."""
|
16
16
|
|
17
17
|
|
18
|
+
import io
|
18
19
|
import timeit
|
19
|
-
from logging import
|
20
|
+
from logging import INFO
|
20
21
|
from typing import Optional, cast
|
21
22
|
|
22
23
|
import flwr.common.recordset_compat as compat
|
@@ -27,11 +28,7 @@ from ..compat.app_utils import start_update_client_manager_thread
|
|
27
28
|
from ..compat.legacy_context import LegacyContext
|
28
29
|
from ..driver import Driver
|
29
30
|
from ..typing import Workflow
|
30
|
-
|
31
|
-
KEY_CURRENT_ROUND = "current_round"
|
32
|
-
KEY_START_TIME = "start_time"
|
33
|
-
CONFIGS_RECORD_KEY = "config"
|
34
|
-
PARAMS_RECORD_KEY = "parameters"
|
31
|
+
from .constant import MAIN_CONFIGS_RECORD, MAIN_PARAMS_RECORD, Key
|
35
32
|
|
36
33
|
|
37
34
|
class DefaultWorkflow:
|
@@ -62,17 +59,19 @@ class DefaultWorkflow:
|
|
62
59
|
)
|
63
60
|
|
64
61
|
# Initialize parameters
|
62
|
+
log(INFO, "[INIT]")
|
65
63
|
default_init_params_workflow(driver, context)
|
66
64
|
|
67
65
|
# Run federated learning for num_rounds
|
68
|
-
log(INFO, "FL starting")
|
69
66
|
start_time = timeit.default_timer()
|
70
67
|
cfg = ConfigsRecord()
|
71
|
-
cfg[
|
72
|
-
context.state.configs_records[
|
68
|
+
cfg[Key.START_TIME] = start_time
|
69
|
+
context.state.configs_records[MAIN_CONFIGS_RECORD] = cfg
|
73
70
|
|
74
71
|
for current_round in range(1, context.config.num_rounds + 1):
|
75
|
-
|
72
|
+
log(INFO, "")
|
73
|
+
log(INFO, "[ROUND %s]", current_round)
|
74
|
+
cfg[Key.CURRENT_ROUND] = current_round
|
76
75
|
|
77
76
|
# Fit round
|
78
77
|
self.fit_workflow(driver, context)
|
@@ -83,22 +82,19 @@ class DefaultWorkflow:
|
|
83
82
|
# Evaluate round
|
84
83
|
self.evaluate_workflow(driver, context)
|
85
84
|
|
86
|
-
# Bookkeeping
|
85
|
+
# Bookkeeping and log results
|
87
86
|
end_time = timeit.default_timer()
|
88
87
|
elapsed = end_time - start_time
|
89
|
-
log(INFO, "FL finished in %s", elapsed)
|
90
|
-
|
91
|
-
# Log results
|
92
88
|
hist = context.history
|
93
|
-
log(INFO, "
|
94
|
-
log(
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
log(INFO, "
|
89
|
+
log(INFO, "")
|
90
|
+
log(INFO, "[SUMMARY]")
|
91
|
+
log(INFO, "Run finished %s rounds in %.2fs", context.config.num_rounds, elapsed)
|
92
|
+
for idx, line in enumerate(io.StringIO(str(hist))):
|
93
|
+
if idx == 0:
|
94
|
+
log(INFO, "%s", line.strip("\n"))
|
95
|
+
else:
|
96
|
+
log(INFO, "\t%s", line.strip("\n"))
|
97
|
+
log(INFO, "")
|
102
98
|
|
103
99
|
# Terminate the thread
|
104
100
|
f_stop.set()
|
@@ -111,12 +107,11 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
|
|
111
107
|
if not isinstance(context, LegacyContext):
|
112
108
|
raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
|
113
109
|
|
114
|
-
log(INFO, "Initializing global parameters")
|
115
110
|
parameters = context.strategy.initialize_parameters(
|
116
111
|
client_manager=context.client_manager
|
117
112
|
)
|
118
113
|
if parameters is not None:
|
119
|
-
log(INFO, "Using initial parameters provided by strategy")
|
114
|
+
log(INFO, "Using initial global parameters provided by strategy")
|
120
115
|
paramsrecord = compat.parameters_to_parametersrecord(
|
121
116
|
parameters, keep_input=True
|
122
117
|
)
|
@@ -141,10 +136,10 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
|
|
141
136
|
msg = list(messages)[0]
|
142
137
|
paramsrecord = next(iter(msg.content.parameters_records.values()))
|
143
138
|
|
144
|
-
context.state.parameters_records[
|
139
|
+
context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
|
145
140
|
|
146
141
|
# Evaluate initial parameters
|
147
|
-
log(INFO, "Evaluating initial parameters")
|
142
|
+
log(INFO, "Evaluating initial global parameters")
|
148
143
|
parameters = compat.parametersrecord_to_parameters(paramsrecord, keep_input=True)
|
149
144
|
res = context.strategy.evaluate(0, parameters=parameters)
|
150
145
|
if res is not None:
|
@@ -164,13 +159,13 @@ def default_centralized_evaluation_workflow(_: Driver, context: Context) -> None
|
|
164
159
|
raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
|
165
160
|
|
166
161
|
# Retrieve current_round and start_time from the context
|
167
|
-
cfg = context.state.configs_records[
|
168
|
-
current_round = cast(int, cfg[
|
169
|
-
start_time = cast(float, cfg[
|
162
|
+
cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
|
163
|
+
current_round = cast(int, cfg[Key.CURRENT_ROUND])
|
164
|
+
start_time = cast(float, cfg[Key.START_TIME])
|
170
165
|
|
171
166
|
# Centralized evaluation
|
172
167
|
parameters = compat.parametersrecord_to_parameters(
|
173
|
-
record=context.state.parameters_records[
|
168
|
+
record=context.state.parameters_records[MAIN_PARAMS_RECORD],
|
174
169
|
keep_input=True,
|
175
170
|
)
|
176
171
|
res_cen = context.strategy.evaluate(current_round, parameters=parameters)
|
@@ -190,15 +185,17 @@ def default_centralized_evaluation_workflow(_: Driver, context: Context) -> None
|
|
190
185
|
)
|
191
186
|
|
192
187
|
|
193
|
-
def default_fit_workflow(
|
188
|
+
def default_fit_workflow( # pylint: disable=R0914
|
189
|
+
driver: Driver, context: Context
|
190
|
+
) -> None:
|
194
191
|
"""Execute the default workflow for a single fit round."""
|
195
192
|
if not isinstance(context, LegacyContext):
|
196
193
|
raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
|
197
194
|
|
198
195
|
# Get current_round and parameters
|
199
|
-
cfg = context.state.configs_records[
|
200
|
-
current_round = cast(int, cfg[
|
201
|
-
parametersrecord = context.state.parameters_records[
|
196
|
+
cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
|
197
|
+
current_round = cast(int, cfg[Key.CURRENT_ROUND])
|
198
|
+
parametersrecord = context.state.parameters_records[MAIN_PARAMS_RECORD]
|
202
199
|
parameters = compat.parametersrecord_to_parameters(
|
203
200
|
parametersrecord, keep_input=True
|
204
201
|
)
|
@@ -211,12 +208,11 @@ def default_fit_workflow(driver: Driver, context: Context) -> None:
|
|
211
208
|
)
|
212
209
|
|
213
210
|
if not client_instructions:
|
214
|
-
log(INFO, "
|
211
|
+
log(INFO, "configure_fit: no clients selected, cancel")
|
215
212
|
return
|
216
213
|
log(
|
217
|
-
|
218
|
-
"
|
219
|
-
current_round,
|
214
|
+
INFO,
|
215
|
+
"configure_fit: strategy sampled %s clients (out of %s)",
|
220
216
|
len(client_instructions),
|
221
217
|
context.client_manager.num_available(),
|
222
218
|
)
|
@@ -240,14 +236,14 @@ def default_fit_workflow(driver: Driver, context: Context) -> None:
|
|
240
236
|
# collect `fit` results from all clients participating in this round
|
241
237
|
messages = list(driver.send_and_receive(out_messages))
|
242
238
|
del out_messages
|
239
|
+
num_failures = len([msg for msg in messages if msg.has_error()])
|
243
240
|
|
244
241
|
# No exception/failure handling currently
|
245
242
|
log(
|
246
|
-
|
247
|
-
"
|
248
|
-
|
249
|
-
|
250
|
-
0,
|
243
|
+
INFO,
|
244
|
+
"aggregate_fit: received %s results and %s failures",
|
245
|
+
len(messages) - num_failures,
|
246
|
+
num_failures,
|
251
247
|
)
|
252
248
|
|
253
249
|
# Aggregate training results
|
@@ -266,7 +262,7 @@ def default_fit_workflow(driver: Driver, context: Context) -> None:
|
|
266
262
|
paramsrecord = compat.parameters_to_parametersrecord(
|
267
263
|
parameters_aggregated, True
|
268
264
|
)
|
269
|
-
context.state.parameters_records[
|
265
|
+
context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
|
270
266
|
context.history.add_metrics_distributed_fit(
|
271
267
|
server_round=current_round, metrics=metrics_aggregated
|
272
268
|
)
|
@@ -278,9 +274,9 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
|
|
278
274
|
raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
|
279
275
|
|
280
276
|
# Get current_round and parameters
|
281
|
-
cfg = context.state.configs_records[
|
282
|
-
current_round = cast(int, cfg[
|
283
|
-
parametersrecord = context.state.parameters_records[
|
277
|
+
cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
|
278
|
+
current_round = cast(int, cfg[Key.CURRENT_ROUND])
|
279
|
+
parametersrecord = context.state.parameters_records[MAIN_PARAMS_RECORD]
|
284
280
|
parameters = compat.parametersrecord_to_parameters(
|
285
281
|
parametersrecord, keep_input=True
|
286
282
|
)
|
@@ -292,12 +288,11 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
|
|
292
288
|
client_manager=context.client_manager,
|
293
289
|
)
|
294
290
|
if not client_instructions:
|
295
|
-
log(INFO, "
|
291
|
+
log(INFO, "configure_evaluate: no clients selected, skipping evaluation")
|
296
292
|
return
|
297
293
|
log(
|
298
|
-
|
299
|
-
"
|
300
|
-
current_round,
|
294
|
+
INFO,
|
295
|
+
"configure_evaluate: strategy sampled %s clients (out of %s)",
|
301
296
|
len(client_instructions),
|
302
297
|
context.client_manager.num_available(),
|
303
298
|
)
|
@@ -321,14 +316,14 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
|
|
321
316
|
# collect `evaluate` results from all clients participating in this round
|
322
317
|
messages = list(driver.send_and_receive(out_messages))
|
323
318
|
del out_messages
|
319
|
+
num_failures = len([msg for msg in messages if msg.has_error()])
|
324
320
|
|
325
321
|
# No exception/failure handling currently
|
326
322
|
log(
|
327
|
-
|
328
|
-
"
|
329
|
-
|
330
|
-
|
331
|
-
0,
|
323
|
+
INFO,
|
324
|
+
"aggregate_evaluate: received %s results and %s failures",
|
325
|
+
len(messages) - num_failures,
|
326
|
+
num_failures,
|
332
327
|
)
|
333
328
|
|
334
329
|
# Aggregate the evaluation results
|
@@ -0,0 +1,24 @@
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Secure Aggregation workflows."""
|
16
|
+
|
17
|
+
|
18
|
+
from .secagg_workflow import SecAggWorkflow
|
19
|
+
from .secaggplus_workflow import SecAggPlusWorkflow
|
20
|
+
|
21
|
+
__all__ = [
|
22
|
+
"SecAggPlusWorkflow",
|
23
|
+
"SecAggWorkflow",
|
24
|
+
]
|
@@ -0,0 +1,112 @@
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Workflow for the SecAgg protocol."""
|
16
|
+
|
17
|
+
|
18
|
+
from typing import Optional, Union
|
19
|
+
|
20
|
+
from .secaggplus_workflow import SecAggPlusWorkflow
|
21
|
+
|
22
|
+
|
23
|
+
class SecAggWorkflow(SecAggPlusWorkflow):
|
24
|
+
"""The workflow for the SecAgg protocol.
|
25
|
+
|
26
|
+
The SecAgg protocol ensures the secure summation of integer vectors owned by
|
27
|
+
multiple parties, without accessing any individual integer vector. This workflow
|
28
|
+
allows the server to compute the weighted average of model parameters across all
|
29
|
+
clients, ensuring individual contributions remain private. This is achieved by
|
30
|
+
clients sending both a weighting factor and a weighted version of the locally
|
31
|
+
updated parameters, both of which are masked for privacy. Specifically, each
|
32
|
+
client uploads "[w, w * params]" with masks, where weighting factor 'w' is the
|
33
|
+
number of examples ('num_examples') and 'params' represents the model parameters
|
34
|
+
('parameters') from the client's `FitRes`. The server then aggregates these
|
35
|
+
contributions to compute the weighted average of model parameters.
|
36
|
+
|
37
|
+
The protocol involves four main stages:
|
38
|
+
- 'setup': Send SecAgg configuration to clients and collect their public keys.
|
39
|
+
- 'share keys': Broadcast public keys among clients and collect encrypted secret
|
40
|
+
key shares.
|
41
|
+
- 'collect masked vectors': Forward encrypted secret key shares to target clients
|
42
|
+
and collect masked model parameters.
|
43
|
+
- 'unmask': Collect secret key shares to decrypt and aggregate the model parameters.
|
44
|
+
|
45
|
+
Only the aggregated model parameters are exposed and passed to
|
46
|
+
`Strategy.aggregate_fit`, ensuring individual data privacy.
|
47
|
+
|
48
|
+
Parameters
|
49
|
+
----------
|
50
|
+
reconstruction_threshold : Union[int, float]
|
51
|
+
The minimum number of shares required to reconstruct a client's private key,
|
52
|
+
or, if specified as a float, it represents the proportion of the total number
|
53
|
+
of shares needed for reconstruction. This threshold ensures privacy by allowing
|
54
|
+
for the recovery of contributions from dropped clients during aggregation,
|
55
|
+
without compromising individual client data.
|
56
|
+
max_weight : float (default: 1000.0)
|
57
|
+
The maximum value of the weight that can be assigned to any single client's
|
58
|
+
update during the weighted average calculation on the server side, e.g., in the
|
59
|
+
FedAvg algorithm.
|
60
|
+
clipping_range : float, optional (default: 8.0)
|
61
|
+
The range within which model parameters are clipped before quantization.
|
62
|
+
This parameter ensures each model parameter is bounded within
|
63
|
+
[-clipping_range, clipping_range], facilitating quantization.
|
64
|
+
quantization_range : int, optional (default: 4194304, this equals 2**22)
|
65
|
+
The size of the range into which floating-point model parameters are quantized,
|
66
|
+
mapping each parameter to an integer in [0, quantization_range-1]. This
|
67
|
+
facilitates cryptographic operations on the model updates.
|
68
|
+
modulus_range : int, optional (default: 4294967296, this equals 2**32)
|
69
|
+
The range of values from which random mask entries are uniformly sampled
|
70
|
+
([0, modulus_range-1]). `modulus_range` must not exceed 4294967296 (2**32).
|
71
|
+
Please use 2**n values for `modulus_range` to prevent overflow issues.
|
72
|
+
timeout : Optional[float] (default: None)
|
73
|
+
The timeout duration in seconds. If specified, the workflow will wait for
|
74
|
+
replies for this duration each time. If `None`, there is no time limit and
|
75
|
+
the workflow will wait until replies for all messages are received.
|
76
|
+
|
77
|
+
Notes
|
78
|
+
-----
|
79
|
+
- Each client's private key is split into N shares under the SecAgg protocol, where
|
80
|
+
N is the number of selected clients.
|
81
|
+
- Generally, higher `reconstruction_threshold` means better privacy guarantees but
|
82
|
+
less tolerance to dropouts.
|
83
|
+
- Too large `max_weight` may compromise the precision of the quantization.
|
84
|
+
- `modulus_range` must be 2**n and larger than `quantization_range`.
|
85
|
+
- When `reconstruction_threshold` is a float, it is interpreted as the proportion of
|
86
|
+
the number of all selected clients needed for the reconstruction of a private key.
|
87
|
+
This feature enables flexibility in setting the security threshold relative to the
|
88
|
+
number of selected clients.
|
89
|
+
- `reconstruction_threshold` and the quantization parameters
|
90
|
+
(`clipping_range`, `quantization_range`, `modulus_range`) play critical roles in
|
91
|
+
balancing privacy, robustness, and efficiency within the SecAgg protocol.
|
92
|
+
"""
|
93
|
+
|
94
|
+
def __init__( # pylint: disable=R0913
|
95
|
+
self,
|
96
|
+
reconstruction_threshold: Union[int, float],
|
97
|
+
*,
|
98
|
+
max_weight: float = 1000.0,
|
99
|
+
clipping_range: float = 8.0,
|
100
|
+
quantization_range: int = 4194304,
|
101
|
+
modulus_range: int = 4294967296,
|
102
|
+
timeout: Optional[float] = None,
|
103
|
+
) -> None:
|
104
|
+
super().__init__(
|
105
|
+
num_shares=1.0,
|
106
|
+
reconstruction_threshold=reconstruction_threshold,
|
107
|
+
max_weight=max_weight,
|
108
|
+
clipping_range=clipping_range,
|
109
|
+
quantization_range=quantization_range,
|
110
|
+
modulus_range=modulus_range,
|
111
|
+
timeout=timeout,
|
112
|
+
)
|