flwr-nightly 1.8.0.dev20240227__py3-none-any.whl → 1.8.0.dev20240229__py3-none-any.whl
- flwr/client/mod/__init__.py +3 -2
- flwr/client/mod/centraldp_mods.py +63 -2
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +55 -75
- flwr/common/differential_privacy.py +77 -0
- flwr/common/differential_privacy_constants.py +1 -0
- flwr/common/secure_aggregation/secaggplus_constants.py +49 -27
- flwr/proto/error_pb2.py +26 -0
- flwr/proto/error_pb2.pyi +25 -0
- flwr/proto/error_pb2_grpc.py +4 -0
- flwr/proto/error_pb2_grpc.pyi +4 -0
- flwr/proto/task_pb2.py +8 -7
- flwr/proto/task_pb2.pyi +7 -2
- flwr/server/__init__.py +4 -0
- flwr/server/app.py +8 -31
- flwr/server/client_proxy.py +5 -0
- flwr/server/compat/__init__.py +2 -0
- flwr/server/compat/app.py +7 -88
- flwr/server/compat/app_utils.py +102 -0
- flwr/server/compat/driver_client_proxy.py +22 -10
- flwr/server/compat/legacy_context.py +55 -0
- flwr/server/run_serverapp.py +1 -1
- flwr/server/server.py +18 -8
- flwr/server/strategy/__init__.py +24 -14
- flwr/server/strategy/dp_adaptive_clipping.py +449 -0
- flwr/server/strategy/dp_fixed_clipping.py +5 -7
- flwr/server/superlink/driver/driver_grpc.py +54 -0
- flwr/server/superlink/driver/driver_servicer.py +4 -4
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +5 -0
- flwr/server/superlink/fleet/vce/__init__.py +1 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -4
- flwr/server/superlink/fleet/vce/vce_api.py +236 -16
- flwr/server/typing.py +1 -0
- flwr/server/workflow/__init__.py +22 -0
- flwr/server/workflow/default_workflows.py +357 -0
- flwr/simulation/__init__.py +3 -0
- flwr/simulation/ray_transport/ray_client_proxy.py +28 -8
- flwr/simulation/run_simulation.py +177 -0
- {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/METADATA +4 -3
- {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/RECORD +42 -31
- {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/entry_points.txt +1 -0
- {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/WHEEL +0 -0
flwr/client/mod/__init__.py CHANGED
@@ -15,12 +15,13 @@
 """Mods."""


-from .centraldp_mods import fixedclipping_mod
+from .centraldp_mods import adaptiveclipping_mod, fixedclipping_mod
 from .secure_aggregation.secaggplus_mod import secaggplus_mod
 from .utils import make_ffn

 __all__ = [
+    "adaptiveclipping_mod",
+    "fixedclipping_mod",
     "make_ffn",
     "secaggplus_mod",
-    "fixedclipping_mod",
 ]
flwr/client/mod/centraldp_mods.py CHANGED
@@ -20,8 +20,11 @@ from flwr.common import ndarrays_to_parameters, parameters_to_ndarrays
 from flwr.common import recordset_compat as compat
 from flwr.common.constant import MESSAGE_TYPE_FIT
 from flwr.common.context import Context
-from flwr.common.differential_privacy import compute_clip_model_update
-from flwr.common.differential_privacy_constants import KEY_CLIPPING_NORM
+from flwr.common.differential_privacy import (
+    compute_adaptive_clip_model_update,
+    compute_clip_model_update,
+)
+from flwr.common.differential_privacy_constants import KEY_CLIPPING_NORM, KEY_NORM_BIT
 from flwr.common.message import Message


@@ -74,3 +77,61 @@ def fixedclipping_mod(
     fit_res.parameters = ndarrays_to_parameters(client_to_server_params)
     out_msg.content = compat.fitres_to_recordset(fit_res, keep_input=True)
     return out_msg
+
+
+def adaptiveclipping_mod(
+    msg: Message, ctxt: Context, call_next: ClientAppCallable
+) -> Message:
+    """Client-side adaptive clipping modifier.
+
+    This mod needs to be used with the DifferentialPrivacyClientSideAdaptiveClipping
+    server-side strategy wrapper.
+
+    The wrapper sends the clipping_norm value to the client.
+
+    This mod clips the client model updates before sending them to the server.
+
+    It also sends KEY_NORM_BIT to the server for computing the new clipping value.
+
+    It operates on messages with type MESSAGE_TYPE_FIT.
+
+    Notes
+    -----
+    Consider the order of mods when using multiple.
+
+    Typically, adaptiveclipping_mod should be the last to operate on params.
+    """
+    if msg.metadata.message_type != MESSAGE_TYPE_FIT:
+        return call_next(msg, ctxt)
+
+    fit_ins = compat.recordset_to_fitins(msg.content, keep_input=True)
+
+    if KEY_CLIPPING_NORM not in fit_ins.config:
+        raise KeyError(
+            f"The {KEY_CLIPPING_NORM} value is not supplied by the "
+            f"DifferentialPrivacyClientSideFixedClipping wrapper at"
+            f" the server side."
+        )
+    if not isinstance(fit_ins.config[KEY_CLIPPING_NORM], float):
+        raise ValueError(f"{KEY_CLIPPING_NORM} should be a float value.")
+    clipping_norm = float(fit_ins.config[KEY_CLIPPING_NORM])
+    server_to_client_params = parameters_to_ndarrays(fit_ins.parameters)
+
+    # Call inner app
+    out_msg = call_next(msg, ctxt)
+    fit_res = compat.recordset_to_fitres(out_msg.content, keep_input=True)
+
+    client_to_server_params = parameters_to_ndarrays(fit_res.parameters)
+
+    # Clip the client update
+    norm_bit = compute_adaptive_clip_model_update(
+        client_to_server_params,
+        server_to_client_params,
+        clipping_norm,
+    )
+
+    fit_res.parameters = ndarrays_to_parameters(client_to_server_params)
+
+    fit_res.metrics[KEY_NORM_BIT] = norm_bit
+    out_msg.content = compat.fitres_to_recordset(fit_res, keep_input=True)
+    return out_msg
flwr/client/mod/secure_aggregation/secaggplus_mod.py CHANGED
@@ -52,31 +52,10 @@ from flwr.common.secure_aggregation.ndarrays_arithmetic import (
 )
 from flwr.common.secure_aggregation.quantization import quantize
 from flwr.common.secure_aggregation.secaggplus_constants import (
-    KEY_ACTIVE_SECURE_ID_LIST,
-    KEY_CIPHERTEXT_LIST,
-    KEY_CLIPPING_RANGE,
-    KEY_DEAD_SECURE_ID_LIST,
-    KEY_DESTINATION_LIST,
-    KEY_MASKED_PARAMETERS,
-    KEY_MOD_RANGE,
-    KEY_PUBLIC_KEY_1,
-    KEY_PUBLIC_KEY_2,
-    KEY_SAMPLE_NUMBER,
-    KEY_SECURE_ID,
-    KEY_SECURE_ID_LIST,
-    KEY_SHARE_LIST,
-    KEY_SHARE_NUMBER,
-    KEY_SOURCE_LIST,
-    KEY_STAGE,
-    KEY_TARGET_RANGE,
-    KEY_THRESHOLD,
     RECORD_KEY_CONFIGS,
     RECORD_KEY_STATE,
-    STAGE_COLLECT_MASKED_INPUT,
-    STAGE_SETUP,
-    STAGE_SHARE_KEYS,
-    STAGE_UNMASK,
-    STAGES,
+    Key,
+    Stage,
 )
 from flwr.common.secure_aggregation.secaggplus_utils import (
     pseudo_rand_gen,

@@ -91,7 +70,7 @@ from flwr.common.typing import ConfigsRecordValues, FitRes
 class SecAggPlusState:
     """State of the SecAgg+ protocol."""

-    current_stage: str = STAGE_UNMASK
+    current_stage: str = Stage.UNMASK

     sid: int = 0
     sample_num: int = 0

@@ -187,20 +166,20 @@ def secaggplus_mod(
     check_stage(state.current_stage, configs)

     # Update the current stage
-    state.current_stage = cast(str, configs.pop(KEY_STAGE))
+    state.current_stage = cast(str, configs.pop(Key.STAGE))

     # Check the validity of the configs based on the current stage
     check_configs(state.current_stage, configs)

     # Execute
-    if state.current_stage == STAGE_SETUP:
+    if state.current_stage == Stage.SETUP:
         res = _setup(state, configs)
-    elif state.current_stage == STAGE_SHARE_KEYS:
+    elif state.current_stage == Stage.SHARE_KEYS:
         res = _share_keys(state, configs)
-    elif state.current_stage == STAGE_COLLECT_MASKED_INPUT:
+    elif state.current_stage == Stage.COLLECT_MASKED_INPUT:
         fit = _get_fit_fn(msg, ctxt, call_next)
         res = _collect_masked_input(state, configs, fit)
-    elif state.current_stage == STAGE_UNMASK:
+    elif state.current_stage == Stage.UNMASK:
         res = _unmask(state, configs)
     else:
         raise ValueError(f"Unknown secagg stage: {state.current_stage}")

@@ -215,28 +194,29 @@ def secaggplus_mod(

 def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
     """Check the validity of the next stage."""
-    # Check the existence of KEY_STAGE
-    if KEY_STAGE not in configs:
+    # Check the existence of Config.STAGE
+    if Key.STAGE not in configs:
         raise KeyError(
-            f"The required key '{KEY_STAGE}' is missing from the input `named_values`."
+            f"The required key '{Key.STAGE}' is missing from the input `named_values`."
         )

-    # Check the value type of the KEY_STAGE
-    next_stage = configs[KEY_STAGE]
+    # Check the value type of the Config.STAGE
+    next_stage = configs[Key.STAGE]
     if not isinstance(next_stage, str):
         raise TypeError(
-            f"The value for the key '{KEY_STAGE}' must be of type {str}, "
+            f"The value for the key '{Key.STAGE}' must be of type {str}, "
             f"but got {type(next_stage)} instead."
         )

     # Check the validity of the next stage
-    if next_stage == STAGE_SETUP:
-        if current_stage != STAGE_UNMASK:
+    if next_stage == Stage.SETUP:
+        if current_stage != Stage.UNMASK:
             log(WARNING, "Restart from the setup stage")
     # If stage is not "setup",
     # the stage from `named_values` should be the expected next stage
     else:
-        expected_next_stage = STAGES[(STAGES.index(current_stage) + 1) % len(STAGES)]
+        stages = Stage.all()
+        expected_next_stage = stages[(stages.index(current_stage) + 1) % len(stages)]
         if next_stage != expected_next_stage:
             raise ValueError(
                 "Abort secure aggregation: "

@@ -248,20 +228,20 @@ def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
 def check_configs(stage: str, configs: ConfigsRecord) -> None:
     """Check the validity of the configs."""
     # Check `named_values` for the setup stage
-    if stage == STAGE_SETUP:
+    if stage == Stage.SETUP:
         key_type_pairs = [
-            (KEY_SAMPLE_NUMBER, int),
-            (KEY_SECURE_ID, int),
-            (KEY_SHARE_NUMBER, int),
-            (KEY_THRESHOLD, int),
-            (KEY_CLIPPING_RANGE, float),
-            (KEY_TARGET_RANGE, int),
-            (KEY_MOD_RANGE, int),
+            (Key.SAMPLE_NUMBER, int),
+            (Key.SECURE_ID, int),
+            (Key.SHARE_NUMBER, int),
+            (Key.THRESHOLD, int),
+            (Key.CLIPPING_RANGE, float),
+            (Key.TARGET_RANGE, int),
+            (Key.MOD_RANGE, int),
         ]
         for key, expected_type in key_type_pairs:
             if key not in configs:
                 raise KeyError(
-                    f"Stage {STAGE_SETUP}: the required key '{key}' is "
+                    f"Stage {Stage.SETUP}: the required key '{key}' is "
                     "missing from the input `named_values`."
                 )
             # Bool is a subclass of int in Python,

@@ -269,11 +249,11 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
             # pylint: disable-next=unidiomatic-typecheck
             if type(configs[key]) is not expected_type:
                 raise TypeError(
-                    f"Stage {STAGE_SETUP}: The value for the key '{key}' "
+                    f"Stage {Stage.SETUP}: The value for the key '{key}' "
                     f"must be of type {expected_type}, "
                     f"but got {type(configs[key])} instead."
                 )
-    elif stage == STAGE_SHARE_KEYS:
+    elif stage == Stage.SHARE_KEYS:
         for key, value in configs.items():
             if (
                 not isinstance(value, list)

@@ -282,18 +262,18 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
                 or not isinstance(value[1], bytes)
             ):
                 raise TypeError(
-                    f"Stage {STAGE_SHARE_KEYS}: "
+                    f"Stage {Stage.SHARE_KEYS}: "
                     f"the value for the key '{key}' must be a list of two bytes."
                 )
-    elif stage == STAGE_COLLECT_MASKED_INPUT:
+    elif stage == Stage.COLLECT_MASKED_INPUT:
         key_type_pairs = [
-            (KEY_CIPHERTEXT_LIST, bytes),
-            (KEY_SOURCE_LIST, int),
+            (Key.CIPHERTEXT_LIST, bytes),
+            (Key.SOURCE_LIST, int),
         ]
         for key, expected_type in key_type_pairs:
             if key not in configs:
                 raise KeyError(
-                    f"Stage {STAGE_COLLECT_MASKED_INPUT}: "
+                    f"Stage {Stage.COLLECT_MASKED_INPUT}: "
                     f"the required key '{key}' is "
                     "missing from the input `named_values`."
                 )

@@ -304,19 +284,19 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
                 if type(elm) is not expected_type
             ):
                 raise TypeError(
-                    f"Stage {STAGE_COLLECT_MASKED_INPUT}: "
+                    f"Stage {Stage.COLLECT_MASKED_INPUT}: "
                     f"the value for the key '{key}' "
                     f"must be of type List[{expected_type.__name__}]"
                 )
-    elif stage == STAGE_UNMASK:
+    elif stage == Stage.UNMASK:
         key_type_pairs = [
-            (KEY_ACTIVE_SECURE_ID_LIST, int),
-            (KEY_DEAD_SECURE_ID_LIST, int),
+            (Key.ACTIVE_SECURE_ID_LIST, int),
+            (Key.DEAD_SECURE_ID_LIST, int),
         ]
         for key, expected_type in key_type_pairs:
             if key not in configs:
                 raise KeyError(
-                    f"Stage {STAGE_UNMASK}: "
+                    f"Stage {Stage.UNMASK}: "
                     f"the required key '{key}' is "
                     "missing from the input `named_values`."
                 )

@@ -327,7 +307,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
                 if type(elm) is not expected_type
             ):
                 raise TypeError(
-                    f"Stage {STAGE_UNMASK}: "
+                    f"Stage {Stage.UNMASK}: "
                     f"the value for the key '{key}' "
                     f"must be of type List[{expected_type.__name__}]"
                 )

@@ -340,15 +320,15 @@ def _setup(
 ) -> Dict[str, ConfigsRecordValues]:
     # Assigning parameter values to object fields
     sec_agg_param_dict = configs
-    state.sample_num = cast(int, sec_agg_param_dict[KEY_SAMPLE_NUMBER])
-    state.sid = cast(int, sec_agg_param_dict[KEY_SECURE_ID])
+    state.sample_num = cast(int, sec_agg_param_dict[Key.SAMPLE_NUMBER])
+    state.sid = cast(int, sec_agg_param_dict[Key.SECURE_ID])
     log(INFO, "Client %d: starting stage 0...", state.sid)

-    state.share_num = cast(int, sec_agg_param_dict[KEY_SHARE_NUMBER])
-    state.threshold = cast(int, sec_agg_param_dict[KEY_THRESHOLD])
-    state.clipping_range = cast(float, sec_agg_param_dict[KEY_CLIPPING_RANGE])
-    state.target_range = cast(int, sec_agg_param_dict[KEY_TARGET_RANGE])
-    state.mod_range = cast(int, sec_agg_param_dict[KEY_MOD_RANGE])
+    state.share_num = cast(int, sec_agg_param_dict[Key.SHARE_NUMBER])
+    state.threshold = cast(int, sec_agg_param_dict[Key.THRESHOLD])
+    state.clipping_range = cast(float, sec_agg_param_dict[Key.CLIPPING_RANGE])
+    state.target_range = cast(int, sec_agg_param_dict[Key.TARGET_RANGE])
+    state.mod_range = cast(int, sec_agg_param_dict[Key.MOD_RANGE])

     # Dictionaries containing client secure IDs as keys
     # and their respective secret shares as values.

@@ -367,7 +347,7 @@ def _setup(
     state.sk1, state.pk1 = private_key_to_bytes(sk1), public_key_to_bytes(pk1)
     state.sk2, state.pk2 = private_key_to_bytes(sk2), public_key_to_bytes(pk2)
     log(INFO, "Client %d: stage 0 completes. uploading public keys...", state.sid)
-    return {KEY_PUBLIC_KEY_1: state.pk1, KEY_PUBLIC_KEY_2: state.pk2}
+    return {Key.PUBLIC_KEY_1: state.pk1, Key.PUBLIC_KEY_2: state.pk2}


 # pylint: disable-next=too-many-locals

@@ -429,7 +409,7 @@ def _share_keys(
         ciphertexts.append(ciphertext)

     log(INFO, "Client %d: stage 1 completes. uploading key shares...", state.sid)
-    return {KEY_DESTINATION_LIST: dsts, KEY_CIPHERTEXT_LIST: ciphertexts}
+    return {Key.DESTINATION_LIST: dsts, Key.CIPHERTEXT_LIST: ciphertexts}


 # pylint: disable-next=too-many-locals

@@ -440,8 +420,8 @@ def _collect_masked_input(
 ) -> Dict[str, ConfigsRecordValues]:
     log(INFO, "Client %d: starting stage 2...", state.sid)
     available_clients: List[int] = []
-    ciphertexts = cast(List[bytes], configs[KEY_CIPHERTEXT_LIST])
-    srcs = cast(List[int], configs[KEY_SOURCE_LIST])
+    ciphertexts = cast(List[bytes], configs[Key.CIPHERTEXT_LIST])
+    srcs = cast(List[int], configs[Key.SOURCE_LIST])
     if len(ciphertexts) + 1 < state.threshold:
         raise ValueError("Not enough available neighbour clients.")

@@ -505,7 +485,7 @@ def _collect_masked_input(
     quantized_parameters = parameters_mod(quantized_parameters, state.mod_range)
     log(INFO, "Client %d: stage 2 completes. uploading masked parameters...", state.sid)
     return {
-        KEY_MASKED_PARAMETERS: [ndarray_to_bytes(arr) for arr in quantized_parameters]
+        Key.MASKED_PARAMETERS: [ndarray_to_bytes(arr) for arr in quantized_parameters]
     }


@@ -514,8 +494,8 @@ def _unmask(
 ) -> Dict[str, ConfigsRecordValues]:
     log(INFO, "Client %d: starting stage 3...", state.sid)

-    active_sids = cast(List[int], configs[KEY_ACTIVE_SECURE_ID_LIST])
-    dead_sids = cast(List[int], configs[KEY_DEAD_SECURE_ID_LIST])
+    active_sids = cast(List[int], configs[Key.ACTIVE_SECURE_ID_LIST])
+    dead_sids = cast(List[int], configs[Key.DEAD_SECURE_ID_LIST])
     # Send private mask seed share for every avaliable client (including itclient)
     # Send first private key share for building pairwise mask for every dropped client
     if len(active_sids) < state.threshold:

@@ -528,4 +508,4 @@ def _unmask(
     shares += [state.sk1_share_dict[sid] for sid in dead_sids]

     log(INFO, "Client %d: stage 3 completes. uploading key shares...", state.sid)
-    return {KEY_SECURE_ID_LIST: sids, KEY_SHARE_LIST: shares}
+    return {Key.SECURE_ID_LIST: sids, Key.SHARE_LIST: shares}
flwr/common/differential_privacy.py CHANGED
@@ -15,6 +15,9 @@
 """Utility functions for differential privacy."""


+from logging import WARNING
+from typing import Optional, Tuple
+
 import numpy as np

 from flwr.common import (

@@ -23,6 +26,7 @@ from flwr.common import (
     ndarrays_to_parameters,
     parameters_to_ndarrays,
 )
+from flwr.common.logger import log


 def get_norm(input_arrays: NDArrays) -> float:

@@ -72,6 +76,36 @@ def compute_clip_model_update(
         param1[i] = param2[i] + model_update[i]


+def adaptive_clip_inputs_inplace(input_arrays: NDArrays, clipping_norm: float) -> bool:
+    """Clip model update based on the clipping norm in-place.
+
+    It returns true if scaling_factor < 1 which is used for norm_bit
+    FlatClip method of the paper: https://arxiv.org/abs/1710.06963
+    """
+    input_norm = get_norm(input_arrays)
+    scaling_factor = min(1, clipping_norm / input_norm)
+    for array in input_arrays:
+        array *= scaling_factor
+    return scaling_factor < 1
+
+
+def compute_adaptive_clip_model_update(
+    param1: NDArrays, param2: NDArrays, clipping_norm: float
+) -> bool:
+    """Compute model update, clip it, then add the clipped value to param1.
+
+    model update = param1 - param2
+    Return the norm_bit
+    """
+    model_update = [np.subtract(x, y) for (x, y) in zip(param1, param2)]
+    norm_bit = adaptive_clip_inputs_inplace(model_update, clipping_norm)
+
+    for i, _ in enumerate(param2):
+        param1[i] = param2[i] + model_update[i]
+
+    return norm_bit
+
+
 def add_gaussian_noise_to_params(
     model_params: Parameters,
     noise_multiplier: float,

@@ -85,3 +119,46 @@ def add_gaussian_noise_to_params(
         compute_stdv(noise_multiplier, clipping_norm, num_sampled_clients),
     )
     return ndarrays_to_parameters(model_params_ndarrays)
+
+
+def compute_adaptive_noise_params(
+    noise_multiplier: float,
+    num_sampled_clients: float,
+    clipped_count_stddev: Optional[float],
+) -> Tuple[float, float]:
+    """Compute noising parameters for the adaptive clipping.
+
+    Paper: https://arxiv.org/abs/1905.03871
+    """
+    if noise_multiplier > 0:
+        if clipped_count_stddev is None:
+            clipped_count_stddev = num_sampled_clients / 20
+        if noise_multiplier >= 2 * clipped_count_stddev:
+            raise ValueError(
+                f"If not specified, `clipped_count_stddev` is set to "
+                f"`num_sampled_clients`/20 by default. This value "
+                f"({num_sampled_clients / 20}) is too low to achieve the "
+                f"desired effective `noise_multiplier` ({noise_multiplier}). "
+                f"Consider increasing `clipped_count_stddev` or decreasing "
+                f"`noise_multiplier`."
+            )
+        noise_multiplier_value = (
+            noise_multiplier ** (-2) - (2 * clipped_count_stddev) ** (-2)
+        ) ** -0.5
+
+        adding_noise = noise_multiplier_value / noise_multiplier
+        if adding_noise >= 2:
+            log(
+                WARNING,
+                "A significant amount of noise (%s) has to be "
+                "added. Consider increasing `clipped_count_stddev` or "
+                "`num_sampled_clients`.",
+                adding_noise,
+            )
+
+    else:
+        if clipped_count_stddev is None:
+            clipped_count_stddev = 0.0
+        noise_multiplier_value = 0.0
+
+    return clipped_count_stddev, noise_multiplier_value
flwr/common/secure_aggregation/secaggplus_constants.py CHANGED
@@ -14,33 +14,55 @@
 # ==============================================================================
 """Constants for the SecAgg/SecAgg+ protocol."""

+
+from __future__ import annotations
+
 RECORD_KEY_STATE = "secaggplus_state"
 RECORD_KEY_CONFIGS = "secaggplus_configs"

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+class Stage:
+    """Stages for the SecAgg+ protocol."""
+
+    SETUP = "setup"
+    SHARE_KEYS = "share_keys"
+    COLLECT_MASKED_INPUT = "collect_masked_input"
+    UNMASK = "unmask"
+    _stages = (SETUP, SHARE_KEYS, COLLECT_MASKED_INPUT, UNMASK)
+
+    @classmethod
+    def all(cls) -> tuple[str, str, str, str]:
+        """Return all stages."""
+        return cls._stages
+
+    def __new__(cls) -> Stage:
+        """Prevent instantiation."""
+        raise TypeError(f"{cls.__name__} cannot be instantiated.")
+
+
+class Key:
+    """Keys for the configs in the ConfigsRecord."""
+
+    STAGE = "stage"
+    SAMPLE_NUMBER = "sample_num"
+    SECURE_ID = "secure_id"
+    SHARE_NUMBER = "share_num"
+    THRESHOLD = "threshold"
+    CLIPPING_RANGE = "clipping_range"
+    TARGET_RANGE = "target_range"
+    MOD_RANGE = "mod_range"
+    PUBLIC_KEY_1 = "pk1"
+    PUBLIC_KEY_2 = "pk2"
+    DESTINATION_LIST = "dsts"
+    CIPHERTEXT_LIST = "ctxts"
+    SOURCE_LIST = "srcs"
+    PARAMETERS = "params"
+    MASKED_PARAMETERS = "masked_params"
+    ACTIVE_SECURE_ID_LIST = "active_sids"
+    DEAD_SECURE_ID_LIST = "dead_sids"
+    SECURE_ID_LIST = "sids"
+    SHARE_LIST = "shares"
+
+    def __new__(cls) -> Key:
+        """Prevent instantiation."""
+        raise TypeError(f"{cls.__name__} cannot be instantiated.")
flwr/proto/error_pb2.py ADDED
@@ -0,0 +1,26 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: flwr/proto/error.proto
+# Protobuf Python Version: 4.25.0
+"""Generated protocol buffer code."""
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/error.proto\x12\nflwr.proto\"%\n\x05\x45rror\x12\x0c\n\x04\x63ode\x18\x01 \x01(\x12\x12\x0e\n\x06reason\x18\x02 \x01(\tb\x06proto3')
+
+_globals = globals()
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.error_pb2', _globals)
+if _descriptor._USE_C_DESCRIPTORS == False:
+  DESCRIPTOR._options = None
+  _globals['_ERROR']._serialized_start=38
+  _globals['_ERROR']._serialized_end=75
+# @@protoc_insertion_point(module_scope)
flwr/proto/error_pb2.pyi ADDED
@@ -0,0 +1,25 @@
+"""
+@generated by mypy-protobuf. Do not edit manually!
+isort:skip_file
+"""
+import builtins
+import google.protobuf.descriptor
+import google.protobuf.message
+import typing
+import typing_extensions
+
+DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
+
+class Error(google.protobuf.message.Message):
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+    CODE_FIELD_NUMBER: builtins.int
+    REASON_FIELD_NUMBER: builtins.int
+    code: builtins.int
+    reason: typing.Text
+    def __init__(self,
+        *,
+        code: builtins.int = ...,
+        reason: typing.Text = ...,
+        ) -> None: ...
+    def ClearField(self, field_name: typing_extensions.Literal["code",b"code","reason",b"reason"]) -> None: ...
+global___Error = Error
flwr/proto/task_pb2.py CHANGED
@@ -15,19 +15,20 @@ _sym_db = _symbol_database.Default()
 from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2
 from flwr.proto import recordset_pb2 as flwr_dot_proto_dot_recordset__pb2
 from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2
+from flwr.proto import error_pb2 as flwr_dot_proto_dot_error__pb2


-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\"\
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x16\x66lwr/proto/error.proto\"\xf6\x01\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\t\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x0b\n\x03ttl\x18\x05 \x01(\t\x12\x10\n\x08\x61ncestry\x18\x06 \x03(\t\x12\x11\n\ttask_type\x18\x07 \x01(\t\x12(\n\trecordset\x18\x08 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12 \n\x05\x65rror\x18\t \x01(\x0b\x32\x11.flwr.proto.Error\"\\\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\\\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Taskb\x06proto3')

 _globals = globals()
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
 _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.task_pb2', _globals)
 if _descriptor._USE_C_DESCRIPTORS == False:
   DESCRIPTOR._options = None
-  _globals['_TASK']._serialized_start=
-  _globals['_TASK']._serialized_end=
-  _globals['_TASKINS']._serialized_start=
-  _globals['_TASKINS']._serialized_end=
-  _globals['_TASKRES']._serialized_start=
-  _globals['_TASKRES']._serialized_end=
+  _globals['_TASK']._serialized_start=141
+  _globals['_TASK']._serialized_end=387
+  _globals['_TASKINS']._serialized_start=389
+  _globals['_TASKINS']._serialized_end=481
+  _globals['_TASKRES']._serialized_start=483
+  _globals['_TASKRES']._serialized_end=575
 # @@protoc_insertion_point(module_scope)
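
Not part of the diff: the updated descriptor adds an `error` field (tag 9, type flwr.proto.Error) to Task alongside the existing `recordset` field. A small sketch of populating it follows; all field values here are illustrative.

from flwr.proto.error_pb2 import Error
from flwr.proto.task_pb2 import Task, TaskRes

# A TaskRes whose Task reports an error instead of a recordset; `task_type`
# and the error contents are made-up values, and the IDs are left empty.
task_res = TaskRes(
    task_id="",
    group_id="",
    run_id=0,
    task=Task(
        task_type="fit",
        error=Error(code=1, reason="client-side failure"),
    ),
)
print(task_res.task.HasField("error"))  # True
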