flwr-nightly 1.8.0.dev20240227__py3-none-any.whl → 1.8.0.dev20240229__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/client/mod/__init__.py +3 -2
- flwr/client/mod/centraldp_mods.py +63 -2
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +55 -75
- flwr/common/differential_privacy.py +77 -0
- flwr/common/differential_privacy_constants.py +1 -0
- flwr/common/secure_aggregation/secaggplus_constants.py +49 -27
- flwr/proto/error_pb2.py +26 -0
- flwr/proto/error_pb2.pyi +25 -0
- flwr/proto/error_pb2_grpc.py +4 -0
- flwr/proto/error_pb2_grpc.pyi +4 -0
- flwr/proto/task_pb2.py +8 -7
- flwr/proto/task_pb2.pyi +7 -2
- flwr/server/__init__.py +4 -0
- flwr/server/app.py +8 -31
- flwr/server/client_proxy.py +5 -0
- flwr/server/compat/__init__.py +2 -0
- flwr/server/compat/app.py +7 -88
- flwr/server/compat/app_utils.py +102 -0
- flwr/server/compat/driver_client_proxy.py +22 -10
- flwr/server/compat/legacy_context.py +55 -0
- flwr/server/run_serverapp.py +1 -1
- flwr/server/server.py +18 -8
- flwr/server/strategy/__init__.py +24 -14
- flwr/server/strategy/dp_adaptive_clipping.py +449 -0
- flwr/server/strategy/dp_fixed_clipping.py +5 -7
- flwr/server/superlink/driver/driver_grpc.py +54 -0
- flwr/server/superlink/driver/driver_servicer.py +4 -4
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +5 -0
- flwr/server/superlink/fleet/vce/__init__.py +1 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -4
- flwr/server/superlink/fleet/vce/vce_api.py +236 -16
- flwr/server/typing.py +1 -0
- flwr/server/workflow/__init__.py +22 -0
- flwr/server/workflow/default_workflows.py +357 -0
- flwr/simulation/__init__.py +3 -0
- flwr/simulation/ray_transport/ray_client_proxy.py +28 -8
- flwr/simulation/run_simulation.py +177 -0
- {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/METADATA +4 -3
- {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/RECORD +42 -31
- {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/entry_points.txt +1 -0
- {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/WHEEL +0 -0
flwr/client/mod/__init__.py
CHANGED
@@ -15,12 +15,13 @@
|
|
15
15
|
"""Mods."""
|
16
16
|
|
17
17
|
|
18
|
-
from .centraldp_mods import fixedclipping_mod
|
18
|
+
from .centraldp_mods import adaptiveclipping_mod, fixedclipping_mod
|
19
19
|
from .secure_aggregation.secaggplus_mod import secaggplus_mod
|
20
20
|
from .utils import make_ffn
|
21
21
|
|
22
22
|
__all__ = [
|
23
|
+
"adaptiveclipping_mod",
|
24
|
+
"fixedclipping_mod",
|
23
25
|
"make_ffn",
|
24
26
|
"secaggplus_mod",
|
25
|
-
"fixedclipping_mod",
|
26
27
|
]
|
@@ -20,8 +20,11 @@ from flwr.common import ndarrays_to_parameters, parameters_to_ndarrays
|
|
20
20
|
from flwr.common import recordset_compat as compat
|
21
21
|
from flwr.common.constant import MESSAGE_TYPE_FIT
|
22
22
|
from flwr.common.context import Context
|
23
|
-
from flwr.common.differential_privacy import
|
24
|
-
|
23
|
+
from flwr.common.differential_privacy import (
|
24
|
+
compute_adaptive_clip_model_update,
|
25
|
+
compute_clip_model_update,
|
26
|
+
)
|
27
|
+
from flwr.common.differential_privacy_constants import KEY_CLIPPING_NORM, KEY_NORM_BIT
|
25
28
|
from flwr.common.message import Message
|
26
29
|
|
27
30
|
|
@@ -74,3 +77,61 @@ def fixedclipping_mod(
|
|
74
77
|
fit_res.parameters = ndarrays_to_parameters(client_to_server_params)
|
75
78
|
out_msg.content = compat.fitres_to_recordset(fit_res, keep_input=True)
|
76
79
|
return out_msg
|
80
|
+
|
81
|
+
|
82
|
+
def adaptiveclipping_mod(
    msg: Message, ctxt: Context, call_next: ClientAppCallable
) -> Message:
    """Client-side adaptive clipping modifier.

    This mod needs to be used with the DifferentialPrivacyClientSideAdaptiveClipping
    server-side strategy wrapper.

    The wrapper sends the clipping_norm value to the client.

    This mod clips the client model updates before sending them to the server.

    It also sends KEY_NORM_BIT to the server for computing the new clipping value.

    It operates on messages with type MESSAGE_TYPE_FIT.

    Parameters
    ----------
    msg : Message
        Incoming message. Messages whose type is not MESSAGE_TYPE_FIT are
        forwarded to `call_next` untouched.
    ctxt : Context
        Context forwarded unchanged to `call_next`.
    call_next : ClientAppCallable
        The next mod or the client app itself.

    Returns
    -------
    Message
        The reply produced by `call_next`. For fit messages its content is
        rebuilt from a FitRes holding the clipped parameters and the
        KEY_NORM_BIT metric.

    Raises
    ------
    KeyError
        If the fit config does not contain KEY_CLIPPING_NORM.
    ValueError
        If the KEY_CLIPPING_NORM config value is not a float.

    Notes
    -----
    Consider the order of mods when using multiple.

    Typically, adaptiveclipping_mod should be the last to operate on params.
    """
    if msg.metadata.message_type != MESSAGE_TYPE_FIT:
        return call_next(msg, ctxt)

    fit_ins = compat.recordset_to_fitins(msg.content, keep_input=True)

    if KEY_CLIPPING_NORM not in fit_ins.config:
        raise KeyError(
            f"The {KEY_CLIPPING_NORM} value is not supplied by the "
            f"DifferentialPrivacyClientSideAdaptiveClipping wrapper at"
            f" the server side."
        )
    if not isinstance(fit_ins.config[KEY_CLIPPING_NORM], float):
        raise ValueError(f"{KEY_CLIPPING_NORM} should be a float value.")
    clipping_norm = float(fit_ins.config[KEY_CLIPPING_NORM])
    # Decode the server parameters before calling the inner app; they are the
    # reference point against which the client update is measured below.
    server_to_client_params = parameters_to_ndarrays(fit_ins.parameters)

    # Call inner app
    out_msg = call_next(msg, ctxt)
    fit_res = compat.recordset_to_fitres(out_msg.content, keep_input=True)

    client_to_server_params = parameters_to_ndarrays(fit_res.parameters)

    # Clip the client update. `client_to_server_params` is modified in place and
    # the returned flag tells whether the update had to be scaled down.
    norm_bit = compute_adaptive_clip_model_update(
        client_to_server_params,
        server_to_client_params,
        clipping_norm,
    )

    fit_res.parameters = ndarrays_to_parameters(client_to_server_params)

    # Report the flag to the server alongside the regular fit metrics.
    fit_res.metrics[KEY_NORM_BIT] = norm_bit
    out_msg.content = compat.fitres_to_recordset(fit_res, keep_input=True)
    return out_msg
|
@@ -52,31 +52,10 @@ from flwr.common.secure_aggregation.ndarrays_arithmetic import (
|
|
52
52
|
)
|
53
53
|
from flwr.common.secure_aggregation.quantization import quantize
|
54
54
|
from flwr.common.secure_aggregation.secaggplus_constants import (
|
55
|
-
KEY_ACTIVE_SECURE_ID_LIST,
|
56
|
-
KEY_CIPHERTEXT_LIST,
|
57
|
-
KEY_CLIPPING_RANGE,
|
58
|
-
KEY_DEAD_SECURE_ID_LIST,
|
59
|
-
KEY_DESTINATION_LIST,
|
60
|
-
KEY_MASKED_PARAMETERS,
|
61
|
-
KEY_MOD_RANGE,
|
62
|
-
KEY_PUBLIC_KEY_1,
|
63
|
-
KEY_PUBLIC_KEY_2,
|
64
|
-
KEY_SAMPLE_NUMBER,
|
65
|
-
KEY_SECURE_ID,
|
66
|
-
KEY_SECURE_ID_LIST,
|
67
|
-
KEY_SHARE_LIST,
|
68
|
-
KEY_SHARE_NUMBER,
|
69
|
-
KEY_SOURCE_LIST,
|
70
|
-
KEY_STAGE,
|
71
|
-
KEY_TARGET_RANGE,
|
72
|
-
KEY_THRESHOLD,
|
73
55
|
RECORD_KEY_CONFIGS,
|
74
56
|
RECORD_KEY_STATE,
|
75
|
-
|
76
|
-
|
77
|
-
STAGE_SHARE_KEYS,
|
78
|
-
STAGE_UNMASK,
|
79
|
-
STAGES,
|
57
|
+
Key,
|
58
|
+
Stage,
|
80
59
|
)
|
81
60
|
from flwr.common.secure_aggregation.secaggplus_utils import (
|
82
61
|
pseudo_rand_gen,
|
@@ -91,7 +70,7 @@ from flwr.common.typing import ConfigsRecordValues, FitRes
|
|
91
70
|
class SecAggPlusState:
|
92
71
|
"""State of the SecAgg+ protocol."""
|
93
72
|
|
94
|
-
current_stage: str =
|
73
|
+
current_stage: str = Stage.UNMASK
|
95
74
|
|
96
75
|
sid: int = 0
|
97
76
|
sample_num: int = 0
|
@@ -187,20 +166,20 @@ def secaggplus_mod(
|
|
187
166
|
check_stage(state.current_stage, configs)
|
188
167
|
|
189
168
|
# Update the current stage
|
190
|
-
state.current_stage = cast(str, configs.pop(
|
169
|
+
state.current_stage = cast(str, configs.pop(Key.STAGE))
|
191
170
|
|
192
171
|
# Check the validity of the configs based on the current stage
|
193
172
|
check_configs(state.current_stage, configs)
|
194
173
|
|
195
174
|
# Execute
|
196
|
-
if state.current_stage ==
|
175
|
+
if state.current_stage == Stage.SETUP:
|
197
176
|
res = _setup(state, configs)
|
198
|
-
elif state.current_stage ==
|
177
|
+
elif state.current_stage == Stage.SHARE_KEYS:
|
199
178
|
res = _share_keys(state, configs)
|
200
|
-
elif state.current_stage ==
|
179
|
+
elif state.current_stage == Stage.COLLECT_MASKED_INPUT:
|
201
180
|
fit = _get_fit_fn(msg, ctxt, call_next)
|
202
181
|
res = _collect_masked_input(state, configs, fit)
|
203
|
-
elif state.current_stage ==
|
182
|
+
elif state.current_stage == Stage.UNMASK:
|
204
183
|
res = _unmask(state, configs)
|
205
184
|
else:
|
206
185
|
raise ValueError(f"Unknown secagg stage: {state.current_stage}")
|
@@ -215,28 +194,29 @@ def secaggplus_mod(
|
|
215
194
|
|
216
195
|
def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
|
217
196
|
"""Check the validity of the next stage."""
|
218
|
-
# Check the existence of
|
219
|
-
if
|
197
|
+
# Check the existence of Key.STAGE
|
198
|
+
if Key.STAGE not in configs:
|
220
199
|
raise KeyError(
|
221
|
-
f"The required key '{
|
200
|
+
f"The required key '{Key.STAGE}' is missing from the input `named_values`."
|
222
201
|
)
|
223
202
|
|
224
|
-
# Check the value type of the
|
225
|
-
next_stage = configs[
|
203
|
+
# Check the value type of Key.STAGE
|
204
|
+
next_stage = configs[Key.STAGE]
|
226
205
|
if not isinstance(next_stage, str):
|
227
206
|
raise TypeError(
|
228
|
-
f"The value for the key '{
|
207
|
+
f"The value for the key '{Key.STAGE}' must be of type {str}, "
|
229
208
|
f"but got {type(next_stage)} instead."
|
230
209
|
)
|
231
210
|
|
232
211
|
# Check the validity of the next stage
|
233
|
-
if next_stage ==
|
234
|
-
if current_stage !=
|
212
|
+
if next_stage == Stage.SETUP:
|
213
|
+
if current_stage != Stage.UNMASK:
|
235
214
|
log(WARNING, "Restart from the setup stage")
|
236
215
|
# If stage is not "setup",
|
237
216
|
# the stage from `named_values` should be the expected next stage
|
238
217
|
else:
|
239
|
-
|
218
|
+
stages = Stage.all()
|
219
|
+
expected_next_stage = stages[(stages.index(current_stage) + 1) % len(stages)]
|
240
220
|
if next_stage != expected_next_stage:
|
241
221
|
raise ValueError(
|
242
222
|
"Abort secure aggregation: "
|
@@ -248,20 +228,20 @@ def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
|
|
248
228
|
def check_configs(stage: str, configs: ConfigsRecord) -> None:
|
249
229
|
"""Check the validity of the configs."""
|
250
230
|
# Check `named_values` for the setup stage
|
251
|
-
if stage ==
|
231
|
+
if stage == Stage.SETUP:
|
252
232
|
key_type_pairs = [
|
253
|
-
(
|
254
|
-
(
|
255
|
-
(
|
256
|
-
(
|
257
|
-
(
|
258
|
-
(
|
259
|
-
(
|
233
|
+
(Key.SAMPLE_NUMBER, int),
|
234
|
+
(Key.SECURE_ID, int),
|
235
|
+
(Key.SHARE_NUMBER, int),
|
236
|
+
(Key.THRESHOLD, int),
|
237
|
+
(Key.CLIPPING_RANGE, float),
|
238
|
+
(Key.TARGET_RANGE, int),
|
239
|
+
(Key.MOD_RANGE, int),
|
260
240
|
]
|
261
241
|
for key, expected_type in key_type_pairs:
|
262
242
|
if key not in configs:
|
263
243
|
raise KeyError(
|
264
|
-
f"Stage {
|
244
|
+
f"Stage {Stage.SETUP}: the required key '{key}' is "
|
265
245
|
"missing from the input `named_values`."
|
266
246
|
)
|
267
247
|
# Bool is a subclass of int in Python,
|
@@ -269,11 +249,11 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
|
|
269
249
|
# pylint: disable-next=unidiomatic-typecheck
|
270
250
|
if type(configs[key]) is not expected_type:
|
271
251
|
raise TypeError(
|
272
|
-
f"Stage {
|
252
|
+
f"Stage {Stage.SETUP}: The value for the key '{key}' "
|
273
253
|
f"must be of type {expected_type}, "
|
274
254
|
f"but got {type(configs[key])} instead."
|
275
255
|
)
|
276
|
-
elif stage ==
|
256
|
+
elif stage == Stage.SHARE_KEYS:
|
277
257
|
for key, value in configs.items():
|
278
258
|
if (
|
279
259
|
not isinstance(value, list)
|
@@ -282,18 +262,18 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
|
|
282
262
|
or not isinstance(value[1], bytes)
|
283
263
|
):
|
284
264
|
raise TypeError(
|
285
|
-
f"Stage {
|
265
|
+
f"Stage {Stage.SHARE_KEYS}: "
|
286
266
|
f"the value for the key '{key}' must be a list of two bytes."
|
287
267
|
)
|
288
|
-
elif stage ==
|
268
|
+
elif stage == Stage.COLLECT_MASKED_INPUT:
|
289
269
|
key_type_pairs = [
|
290
|
-
(
|
291
|
-
(
|
270
|
+
(Key.CIPHERTEXT_LIST, bytes),
|
271
|
+
(Key.SOURCE_LIST, int),
|
292
272
|
]
|
293
273
|
for key, expected_type in key_type_pairs:
|
294
274
|
if key not in configs:
|
295
275
|
raise KeyError(
|
296
|
-
f"Stage {
|
276
|
+
f"Stage {Stage.COLLECT_MASKED_INPUT}: "
|
297
277
|
f"the required key '{key}' is "
|
298
278
|
"missing from the input `named_values`."
|
299
279
|
)
|
@@ -304,19 +284,19 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
|
|
304
284
|
if type(elm) is not expected_type
|
305
285
|
):
|
306
286
|
raise TypeError(
|
307
|
-
f"Stage {
|
287
|
+
f"Stage {Stage.COLLECT_MASKED_INPUT}: "
|
308
288
|
f"the value for the key '{key}' "
|
309
289
|
f"must be of type List[{expected_type.__name__}]"
|
310
290
|
)
|
311
|
-
elif stage ==
|
291
|
+
elif stage == Stage.UNMASK:
|
312
292
|
key_type_pairs = [
|
313
|
-
(
|
314
|
-
(
|
293
|
+
(Key.ACTIVE_SECURE_ID_LIST, int),
|
294
|
+
(Key.DEAD_SECURE_ID_LIST, int),
|
315
295
|
]
|
316
296
|
for key, expected_type in key_type_pairs:
|
317
297
|
if key not in configs:
|
318
298
|
raise KeyError(
|
319
|
-
f"Stage {
|
299
|
+
f"Stage {Stage.UNMASK}: "
|
320
300
|
f"the required key '{key}' is "
|
321
301
|
"missing from the input `named_values`."
|
322
302
|
)
|
@@ -327,7 +307,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
|
|
327
307
|
if type(elm) is not expected_type
|
328
308
|
):
|
329
309
|
raise TypeError(
|
330
|
-
f"Stage {
|
310
|
+
f"Stage {Stage.UNMASK}: "
|
331
311
|
f"the value for the key '{key}' "
|
332
312
|
f"must be of type List[{expected_type.__name__}]"
|
333
313
|
)
|
@@ -340,15 +320,15 @@ def _setup(
|
|
340
320
|
) -> Dict[str, ConfigsRecordValues]:
|
341
321
|
# Assigning parameter values to object fields
|
342
322
|
sec_agg_param_dict = configs
|
343
|
-
state.sample_num = cast(int, sec_agg_param_dict[
|
344
|
-
state.sid = cast(int, sec_agg_param_dict[
|
323
|
+
state.sample_num = cast(int, sec_agg_param_dict[Key.SAMPLE_NUMBER])
|
324
|
+
state.sid = cast(int, sec_agg_param_dict[Key.SECURE_ID])
|
345
325
|
log(INFO, "Client %d: starting stage 0...", state.sid)
|
346
326
|
|
347
|
-
state.share_num = cast(int, sec_agg_param_dict[
|
348
|
-
state.threshold = cast(int, sec_agg_param_dict[
|
349
|
-
state.clipping_range = cast(float, sec_agg_param_dict[
|
350
|
-
state.target_range = cast(int, sec_agg_param_dict[
|
351
|
-
state.mod_range = cast(int, sec_agg_param_dict[
|
327
|
+
state.share_num = cast(int, sec_agg_param_dict[Key.SHARE_NUMBER])
|
328
|
+
state.threshold = cast(int, sec_agg_param_dict[Key.THRESHOLD])
|
329
|
+
state.clipping_range = cast(float, sec_agg_param_dict[Key.CLIPPING_RANGE])
|
330
|
+
state.target_range = cast(int, sec_agg_param_dict[Key.TARGET_RANGE])
|
331
|
+
state.mod_range = cast(int, sec_agg_param_dict[Key.MOD_RANGE])
|
352
332
|
|
353
333
|
# Dictionaries containing client secure IDs as keys
|
354
334
|
# and their respective secret shares as values.
|
@@ -367,7 +347,7 @@ def _setup(
|
|
367
347
|
state.sk1, state.pk1 = private_key_to_bytes(sk1), public_key_to_bytes(pk1)
|
368
348
|
state.sk2, state.pk2 = private_key_to_bytes(sk2), public_key_to_bytes(pk2)
|
369
349
|
log(INFO, "Client %d: stage 0 completes. uploading public keys...", state.sid)
|
370
|
-
return {
|
350
|
+
return {Key.PUBLIC_KEY_1: state.pk1, Key.PUBLIC_KEY_2: state.pk2}
|
371
351
|
|
372
352
|
|
373
353
|
# pylint: disable-next=too-many-locals
|
@@ -429,7 +409,7 @@ def _share_keys(
|
|
429
409
|
ciphertexts.append(ciphertext)
|
430
410
|
|
431
411
|
log(INFO, "Client %d: stage 1 completes. uploading key shares...", state.sid)
|
432
|
-
return {
|
412
|
+
return {Key.DESTINATION_LIST: dsts, Key.CIPHERTEXT_LIST: ciphertexts}
|
433
413
|
|
434
414
|
|
435
415
|
# pylint: disable-next=too-many-locals
|
@@ -440,8 +420,8 @@ def _collect_masked_input(
|
|
440
420
|
) -> Dict[str, ConfigsRecordValues]:
|
441
421
|
log(INFO, "Client %d: starting stage 2...", state.sid)
|
442
422
|
available_clients: List[int] = []
|
443
|
-
ciphertexts = cast(List[bytes], configs[
|
444
|
-
srcs = cast(List[int], configs[
|
423
|
+
ciphertexts = cast(List[bytes], configs[Key.CIPHERTEXT_LIST])
|
424
|
+
srcs = cast(List[int], configs[Key.SOURCE_LIST])
|
445
425
|
if len(ciphertexts) + 1 < state.threshold:
|
446
426
|
raise ValueError("Not enough available neighbour clients.")
|
447
427
|
|
@@ -505,7 +485,7 @@ def _collect_masked_input(
|
|
505
485
|
quantized_parameters = parameters_mod(quantized_parameters, state.mod_range)
|
506
486
|
log(INFO, "Client %d: stage 2 completes. uploading masked parameters...", state.sid)
|
507
487
|
return {
|
508
|
-
|
488
|
+
Key.MASKED_PARAMETERS: [ndarray_to_bytes(arr) for arr in quantized_parameters]
|
509
489
|
}
|
510
490
|
|
511
491
|
|
@@ -514,8 +494,8 @@ def _unmask(
|
|
514
494
|
) -> Dict[str, ConfigsRecordValues]:
|
515
495
|
log(INFO, "Client %d: starting stage 3...", state.sid)
|
516
496
|
|
517
|
-
active_sids = cast(List[int], configs[
|
518
|
-
dead_sids = cast(List[int], configs[
|
497
|
+
active_sids = cast(List[int], configs[Key.ACTIVE_SECURE_ID_LIST])
|
498
|
+
dead_sids = cast(List[int], configs[Key.DEAD_SECURE_ID_LIST])
|
519
499
|
# Send private mask seed share for every available client (including itself)
|
520
500
|
# Send first private key share for building pairwise mask for every dropped client
|
521
501
|
if len(active_sids) < state.threshold:
|
@@ -528,4 +508,4 @@ def _unmask(
|
|
528
508
|
shares += [state.sk1_share_dict[sid] for sid in dead_sids]
|
529
509
|
|
530
510
|
log(INFO, "Client %d: stage 3 completes. uploading key shares...", state.sid)
|
531
|
-
return {
|
511
|
+
return {Key.SECURE_ID_LIST: sids, Key.SHARE_LIST: shares}
|
@@ -15,6 +15,9 @@
|
|
15
15
|
"""Utility functions for differential privacy."""
|
16
16
|
|
17
17
|
|
18
|
+
from logging import WARNING
|
19
|
+
from typing import Optional, Tuple
|
20
|
+
|
18
21
|
import numpy as np
|
19
22
|
|
20
23
|
from flwr.common import (
|
@@ -23,6 +26,7 @@ from flwr.common import (
|
|
23
26
|
ndarrays_to_parameters,
|
24
27
|
parameters_to_ndarrays,
|
25
28
|
)
|
29
|
+
from flwr.common.logger import log
|
26
30
|
|
27
31
|
|
28
32
|
def get_norm(input_arrays: NDArrays) -> float:
|
@@ -72,6 +76,36 @@ def compute_clip_model_update(
|
|
72
76
|
param1[i] = param2[i] + model_update[i]
|
73
77
|
|
74
78
|
|
79
|
+
def adaptive_clip_inputs_inplace(input_arrays: NDArrays, clipping_norm: float) -> bool:
    """Clip model update based on the clipping norm in-place.

    FlatClip method of the paper: https://arxiv.org/abs/1710.06963

    Parameters
    ----------
    input_arrays : NDArrays
        Arrays forming the model update. Each array is multiplied in place by
        the scaling factor.
    clipping_norm : float
        Norm threshold; the scaling factor is
        ``min(1, clipping_norm / norm(input_arrays))``.

    Returns
    -------
    bool
        True if the scaling factor is below 1, i.e. the update was scaled down.
        This value is used as the norm_bit.
    """
    input_norm = get_norm(input_arrays)
    if input_norm == 0:
        # The ratio below is undefined for a zero norm (float division by zero).
        # A zero-norm update cannot exceed the threshold, so leave it unscaled.
        # The int 1 matches what `min(1, ratio)` yields when no clipping is needed.
        scaling_factor = 1
    else:
        scaling_factor = min(1, clipping_norm / input_norm)
    for array in input_arrays:
        array *= scaling_factor
    return scaling_factor < 1
|
90
|
+
|
91
|
+
|
92
|
+
def compute_adaptive_clip_model_update(
    param1: NDArrays, param2: NDArrays, clipping_norm: float
) -> bool:
    """Clip the update ``param1 - param2`` and write the result back into param1.

    The element-wise difference between `param1` and `param2` is clipped with
    `adaptive_clip_inputs_inplace`. Each entry of `param1` is then replaced by
    the matching entry of `param2` plus the clipped difference.

    Returns
    -------
    bool
        The norm_bit reported by `adaptive_clip_inputs_inplace`.
    """
    delta = []
    for updated, reference in zip(param1, param2):
        delta.append(np.subtract(updated, reference))

    was_clipped = adaptive_clip_inputs_inplace(delta, clipping_norm)

    # Overwrite the entries of the caller's list rather than building a new one.
    for idx, reference in enumerate(param2):
        param1[idx] = reference + delta[idx]

    return was_clipped
|
107
|
+
|
108
|
+
|
75
109
|
def add_gaussian_noise_to_params(
|
76
110
|
model_params: Parameters,
|
77
111
|
noise_multiplier: float,
|
@@ -85,3 +119,46 @@ def add_gaussian_noise_to_params(
|
|
85
119
|
compute_stdv(noise_multiplier, clipping_norm, num_sampled_clients),
|
86
120
|
)
|
87
121
|
return ndarrays_to_parameters(model_params_ndarrays)
|
122
|
+
|
123
|
+
|
124
|
+
def compute_adaptive_noise_params(
    noise_multiplier: float,
    num_sampled_clients: float,
    clipped_count_stddev: Optional[float],
) -> Tuple[float, float]:
    """Compute noising parameters for the adaptive clipping.

    Paper: https://arxiv.org/abs/1905.03871

    Parameters
    ----------
    noise_multiplier : float
        Desired effective noise multiplier. Values that are not greater than 0
        disable noising (the returned multiplier is 0.0).
    num_sampled_clients : float
        Number of sampled clients; only used to derive the default
        `clipped_count_stddev` (``num_sampled_clients / 20``).
    clipped_count_stddev : Optional[float]
        Standard deviation of the noise added to the clipped count. If None, it
        defaults to ``num_sampled_clients / 20`` when `noise_multiplier` > 0 and
        to 0.0 otherwise.

    Returns
    -------
    Tuple[float, float]
        ``(clipped_count_stddev, noise_multiplier_value)`` where the second
        item is the adjusted noise multiplier.

    Raises
    ------
    ValueError
        If `noise_multiplier` > 0 and
        ``noise_multiplier >= 2 * clipped_count_stddev``.
    """
    if noise_multiplier > 0:
        used_default = clipped_count_stddev is None
        if clipped_count_stddev is None:
            clipped_count_stddev = num_sampled_clients / 20
        if noise_multiplier >= 2 * clipped_count_stddev:
            # Report the value that was actually checked, and only mention the
            # default when it was the one in effect.
            if used_default:
                prefix = (
                    f"If not specified, `clipped_count_stddev` is set to "
                    f"`num_sampled_clients`/20 by default. This value "
                    f"({clipped_count_stddev}) is too low to achieve the "
                )
            else:
                prefix = (
                    f"The specified `clipped_count_stddev` "
                    f"({clipped_count_stddev}) is too low to achieve the "
                )
            raise ValueError(
                prefix
                + f"desired effective `noise_multiplier` ({noise_multiplier}). "
                f"Consider increasing `clipped_count_stddev` or decreasing "
                f"`noise_multiplier`."
            )
        # The check above guarantees the bracketed difference is positive, so
        # the negative fractional power is well defined.
        noise_multiplier_value = (
            noise_multiplier ** (-2) - (2 * clipped_count_stddev) ** (-2)
        ) ** -0.5

        # Ratio between the adjusted and the requested multiplier.
        adding_noise = noise_multiplier_value / noise_multiplier
        if adding_noise >= 2:
            log(
                WARNING,
                "A significant amount of noise (%s) has to be "
                "added. Consider increasing `clipped_count_stddev` or "
                "`num_sampled_clients`.",
                adding_noise,
            )

    else:
        if clipped_count_stddev is None:
            clipped_count_stddev = 0.0
        noise_multiplier_value = 0.0

    return clipped_count_stddev, noise_multiplier_value
|
@@ -14,33 +14,55 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
"""Constants for the SecAgg/SecAgg+ protocol."""
|
16
16
|
|
17
|
+
|
18
|
+
from __future__ import annotations
|
19
|
+
|
17
20
|
RECORD_KEY_STATE = "secaggplus_state"
|
18
21
|
RECORD_KEY_CONFIGS = "secaggplus_configs"
|
19
22
|
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
23
|
+
|
24
|
+
class Stage:
    """Names of the stages of the SecAgg+ protocol.

    The class is a plain namespace of string constants; it cannot be
    instantiated.
    """

    SETUP = "setup"
    SHARE_KEYS = "share_keys"
    COLLECT_MASKED_INPUT = "collect_masked_input"
    UNMASK = "unmask"
    # Tuple handed out by `all()`; the order is the declaration order above.
    _stages = (SETUP, SHARE_KEYS, COLLECT_MASKED_INPUT, UNMASK)

    def __new__(cls) -> Stage:
        """Refuse instantiation by always raising TypeError."""
        message = f"{cls.__name__} cannot be instantiated."
        raise TypeError(message)

    @classmethod
    def all(cls) -> tuple[str, str, str, str]:
        """Return the four stage names as a tuple."""
        return cls._stages
|
41
|
+
|
42
|
+
|
43
|
+
class Key:
    """String keys used for entries of the SecAgg+ ConfigsRecord.

    The class is a plain namespace of string constants; it cannot be
    instantiated.
    """

    # Protocol stage selector
    STAGE = "stage"

    # Scalar protocol parameters
    SAMPLE_NUMBER = "sample_num"
    SECURE_ID = "secure_id"
    SHARE_NUMBER = "share_num"
    THRESHOLD = "threshold"
    CLIPPING_RANGE = "clipping_range"
    TARGET_RANGE = "target_range"
    MOD_RANGE = "mod_range"

    # Public keys
    PUBLIC_KEY_1 = "pk1"
    PUBLIC_KEY_2 = "pk2"

    # Key-share exchange
    DESTINATION_LIST = "dsts"
    CIPHERTEXT_LIST = "ctxts"
    SOURCE_LIST = "srcs"

    # Model parameters
    PARAMETERS = "params"
    MASKED_PARAMETERS = "masked_params"

    # Secure-ID and share lists
    ACTIVE_SECURE_ID_LIST = "active_sids"
    DEAD_SECURE_ID_LIST = "dead_sids"
    SECURE_ID_LIST = "sids"
    SHARE_LIST = "shares"

    def __new__(cls) -> Key:
        """Refuse instantiation by always raising TypeError."""
        message = f"{cls.__name__} cannot be instantiated."
        raise TypeError(message)
|
flwr/proto/error_pb2.py
ADDED
@@ -0,0 +1,26 @@
|
|
1
|
+
# -*- coding: utf-8 -*-
|
2
|
+
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
3
|
+
# source: flwr/proto/error.proto
|
4
|
+
# Protobuf Python Version: 4.25.0
|
5
|
+
"""Generated protocol buffer code."""
|
6
|
+
from google.protobuf import descriptor as _descriptor
|
7
|
+
from google.protobuf import descriptor_pool as _descriptor_pool
|
8
|
+
from google.protobuf import symbol_database as _symbol_database
|
9
|
+
from google.protobuf.internal import builder as _builder
|
10
|
+
# @@protoc_insertion_point(imports)
|
11
|
+
|
12
|
+
_sym_db = _symbol_database.Default()
|
13
|
+
|
14
|
+
|
15
|
+
|
16
|
+
|
17
|
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/error.proto\x12\nflwr.proto\"%\n\x05\x45rror\x12\x0c\n\x04\x63ode\x18\x01 \x01(\x12\x12\x0e\n\x06reason\x18\x02 \x01(\tb\x06proto3')
|
18
|
+
|
19
|
+
_globals = globals()
|
20
|
+
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
21
|
+
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.error_pb2', _globals)
|
22
|
+
if _descriptor._USE_C_DESCRIPTORS == False:
|
23
|
+
DESCRIPTOR._options = None
|
24
|
+
_globals['_ERROR']._serialized_start=38
|
25
|
+
_globals['_ERROR']._serialized_end=75
|
26
|
+
# @@protoc_insertion_point(module_scope)
|
flwr/proto/error_pb2.pyi
ADDED
@@ -0,0 +1,25 @@
|
|
1
|
+
"""
|
2
|
+
@generated by mypy-protobuf. Do not edit manually!
|
3
|
+
isort:skip_file
|
4
|
+
"""
|
5
|
+
import builtins
|
6
|
+
import google.protobuf.descriptor
|
7
|
+
import google.protobuf.message
|
8
|
+
import typing
|
9
|
+
import typing_extensions
|
10
|
+
|
11
|
+
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
|
12
|
+
|
13
|
+
class Error(google.protobuf.message.Message):
|
14
|
+
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
15
|
+
CODE_FIELD_NUMBER: builtins.int
|
16
|
+
REASON_FIELD_NUMBER: builtins.int
|
17
|
+
code: builtins.int
|
18
|
+
reason: typing.Text
|
19
|
+
def __init__(self,
|
20
|
+
*,
|
21
|
+
code: builtins.int = ...,
|
22
|
+
reason: typing.Text = ...,
|
23
|
+
) -> None: ...
|
24
|
+
def ClearField(self, field_name: typing_extensions.Literal["code",b"code","reason",b"reason"]) -> None: ...
|
25
|
+
global___Error = Error
|
flwr/proto/task_pb2.py
CHANGED
@@ -15,19 +15,20 @@ _sym_db = _symbol_database.Default()
|
|
15
15
|
from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2
|
16
16
|
from flwr.proto import recordset_pb2 as flwr_dot_proto_dot_recordset__pb2
|
17
17
|
from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2
|
18
|
+
from flwr.proto import error_pb2 as flwr_dot_proto_dot_error__pb2
|
18
19
|
|
19
20
|
|
20
|
-
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\"\
|
21
|
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x16\x66lwr/proto/error.proto\"\xf6\x01\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\t\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x0b\n\x03ttl\x18\x05 \x01(\t\x12\x10\n\x08\x61ncestry\x18\x06 \x03(\t\x12\x11\n\ttask_type\x18\x07 \x01(\t\x12(\n\trecordset\x18\x08 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12 \n\x05\x65rror\x18\t \x01(\x0b\x32\x11.flwr.proto.Error\"\\\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\\\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Taskb\x06proto3')
|
21
22
|
|
22
23
|
_globals = globals()
|
23
24
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
24
25
|
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.task_pb2', _globals)
|
25
26
|
if _descriptor._USE_C_DESCRIPTORS == False:
|
26
27
|
DESCRIPTOR._options = None
|
27
|
-
_globals['_TASK']._serialized_start=
|
28
|
-
_globals['_TASK']._serialized_end=
|
29
|
-
_globals['_TASKINS']._serialized_start=
|
30
|
-
_globals['_TASKINS']._serialized_end=
|
31
|
-
_globals['_TASKRES']._serialized_start=
|
32
|
-
_globals['_TASKRES']._serialized_end=
|
28
|
+
_globals['_TASK']._serialized_start=141
|
29
|
+
_globals['_TASK']._serialized_end=387
|
30
|
+
_globals['_TASKINS']._serialized_start=389
|
31
|
+
_globals['_TASKINS']._serialized_end=481
|
32
|
+
_globals['_TASKRES']._serialized_start=483
|
33
|
+
_globals['_TASKRES']._serialized_end=575
|
33
34
|
# @@protoc_insertion_point(module_scope)
|