flwr-nightly 1.8.0.dev20240227__py3-none-any.whl → 1.8.0.dev20240229__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42):
  1. flwr/client/mod/__init__.py +3 -2
  2. flwr/client/mod/centraldp_mods.py +63 -2
  3. flwr/client/mod/secure_aggregation/secaggplus_mod.py +55 -75
  4. flwr/common/differential_privacy.py +77 -0
  5. flwr/common/differential_privacy_constants.py +1 -0
  6. flwr/common/secure_aggregation/secaggplus_constants.py +49 -27
  7. flwr/proto/error_pb2.py +26 -0
  8. flwr/proto/error_pb2.pyi +25 -0
  9. flwr/proto/error_pb2_grpc.py +4 -0
  10. flwr/proto/error_pb2_grpc.pyi +4 -0
  11. flwr/proto/task_pb2.py +8 -7
  12. flwr/proto/task_pb2.pyi +7 -2
  13. flwr/server/__init__.py +4 -0
  14. flwr/server/app.py +8 -31
  15. flwr/server/client_proxy.py +5 -0
  16. flwr/server/compat/__init__.py +2 -0
  17. flwr/server/compat/app.py +7 -88
  18. flwr/server/compat/app_utils.py +102 -0
  19. flwr/server/compat/driver_client_proxy.py +22 -10
  20. flwr/server/compat/legacy_context.py +55 -0
  21. flwr/server/run_serverapp.py +1 -1
  22. flwr/server/server.py +18 -8
  23. flwr/server/strategy/__init__.py +24 -14
  24. flwr/server/strategy/dp_adaptive_clipping.py +449 -0
  25. flwr/server/strategy/dp_fixed_clipping.py +5 -7
  26. flwr/server/superlink/driver/driver_grpc.py +54 -0
  27. flwr/server/superlink/driver/driver_servicer.py +4 -4
  28. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +5 -0
  29. flwr/server/superlink/fleet/vce/__init__.py +1 -1
  30. flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -4
  31. flwr/server/superlink/fleet/vce/vce_api.py +236 -16
  32. flwr/server/typing.py +1 -0
  33. flwr/server/workflow/__init__.py +22 -0
  34. flwr/server/workflow/default_workflows.py +357 -0
  35. flwr/simulation/__init__.py +3 -0
  36. flwr/simulation/ray_transport/ray_client_proxy.py +28 -8
  37. flwr/simulation/run_simulation.py +177 -0
  38. {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/METADATA +4 -3
  39. {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/RECORD +42 -31
  40. {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/entry_points.txt +1 -0
  41. {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/LICENSE +0 -0
  42. {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/WHEEL +0 -0
@@ -15,12 +15,13 @@
15
15
  """Mods."""
16
16
 
17
17
 
18
- from .centraldp_mods import fixedclipping_mod
18
+ from .centraldp_mods import adaptiveclipping_mod, fixedclipping_mod
19
19
  from .secure_aggregation.secaggplus_mod import secaggplus_mod
20
20
  from .utils import make_ffn
21
21
 
# Public API of the mods package: the mods and helpers imported above.
__all__ = [
    "adaptiveclipping_mod",
    "fixedclipping_mod",
    "make_ffn",
    "secaggplus_mod",
]
@@ -20,8 +20,11 @@ from flwr.common import ndarrays_to_parameters, parameters_to_ndarrays
20
20
  from flwr.common import recordset_compat as compat
21
21
  from flwr.common.constant import MESSAGE_TYPE_FIT
22
22
  from flwr.common.context import Context
23
- from flwr.common.differential_privacy import compute_clip_model_update
24
- from flwr.common.differential_privacy_constants import KEY_CLIPPING_NORM
23
+ from flwr.common.differential_privacy import (
24
+ compute_adaptive_clip_model_update,
25
+ compute_clip_model_update,
26
+ )
27
+ from flwr.common.differential_privacy_constants import KEY_CLIPPING_NORM, KEY_NORM_BIT
25
28
  from flwr.common.message import Message
26
29
 
27
30
 
@@ -74,3 +77,61 @@ def fixedclipping_mod(
74
77
  fit_res.parameters = ndarrays_to_parameters(client_to_server_params)
75
78
  out_msg.content = compat.fitres_to_recordset(fit_res, keep_input=True)
76
79
  return out_msg
80
+
81
+
def adaptiveclipping_mod(
    msg: Message, ctxt: Context, call_next: ClientAppCallable
) -> Message:
    """Client-side adaptive clipping modifier.

    This mod needs to be used with the DifferentialPrivacyClientSideAdaptiveClipping
    server-side strategy wrapper.

    The wrapper sends the clipping_norm value to the client.

    This mod clips the client model updates before sending them to the server.

    It also sends KEY_NORM_BIT to the server for computing the new clipping value.

    It operates on messages with type MESSAGE_TYPE_FIT.

    Parameters
    ----------
    msg : Message
        Incoming message. Messages whose type is not ``MESSAGE_TYPE_FIT`` are
        forwarded to ``call_next`` untouched.
    ctxt : Context
        Context forwarded to ``call_next``.
    call_next : ClientAppCallable
        The next mod (or the inner app) in the chain.

    Returns
    -------
    Message
        The reply produced by ``call_next``. For fit messages its parameters
        are replaced by the server-sent parameters plus the clipped model
        update, and ``metrics[KEY_NORM_BIT]`` records whether the update had
        to be scaled down.

    Raises
    ------
    KeyError
        If the fit config does not contain ``KEY_CLIPPING_NORM``.
    ValueError
        If the supplied clipping norm is not a ``float``.

    Notes
    -----
    Consider the order of mods when using multiple.

    Typically, adaptiveclipping_mod should be the last to operate on params.
    """
    if msg.metadata.message_type != MESSAGE_TYPE_FIT:
        return call_next(msg, ctxt)

    fit_ins = compat.recordset_to_fitins(msg.content, keep_input=True)

    if KEY_CLIPPING_NORM not in fit_ins.config:
        # This mod is paired with the *adaptive* clipping wrapper, so name it
        # in the error to point users at the right server-side strategy.
        raise KeyError(
            f"The {KEY_CLIPPING_NORM} value is not supplied by the "
            f"DifferentialPrivacyClientSideAdaptiveClipping wrapper at"
            f" the server side."
        )
    if not isinstance(fit_ins.config[KEY_CLIPPING_NORM], float):
        raise ValueError(f"{KEY_CLIPPING_NORM} should be a float value.")
    clipping_norm = float(fit_ins.config[KEY_CLIPPING_NORM])
    # Keep the parameters the server sent: the update is measured against them.
    server_to_client_params = parameters_to_ndarrays(fit_ins.parameters)

    # Call inner app
    out_msg = call_next(msg, ctxt)
    fit_res = compat.recordset_to_fitres(out_msg.content, keep_input=True)

    client_to_server_params = parameters_to_ndarrays(fit_res.parameters)

    # Clip the client update (modifies client_to_server_params in place) and
    # obtain the bit telling whether clipping was applied.
    norm_bit = compute_adaptive_clip_model_update(
        client_to_server_params,
        server_to_client_params,
        clipping_norm,
    )

    fit_res.parameters = ndarrays_to_parameters(client_to_server_params)

    # Reported back so the server can adapt the clipping norm.
    fit_res.metrics[KEY_NORM_BIT] = norm_bit
    out_msg.content = compat.fitres_to_recordset(fit_res, keep_input=True)
    return out_msg
@@ -52,31 +52,10 @@ from flwr.common.secure_aggregation.ndarrays_arithmetic import (
52
52
  )
53
53
  from flwr.common.secure_aggregation.quantization import quantize
54
54
  from flwr.common.secure_aggregation.secaggplus_constants import (
55
- KEY_ACTIVE_SECURE_ID_LIST,
56
- KEY_CIPHERTEXT_LIST,
57
- KEY_CLIPPING_RANGE,
58
- KEY_DEAD_SECURE_ID_LIST,
59
- KEY_DESTINATION_LIST,
60
- KEY_MASKED_PARAMETERS,
61
- KEY_MOD_RANGE,
62
- KEY_PUBLIC_KEY_1,
63
- KEY_PUBLIC_KEY_2,
64
- KEY_SAMPLE_NUMBER,
65
- KEY_SECURE_ID,
66
- KEY_SECURE_ID_LIST,
67
- KEY_SHARE_LIST,
68
- KEY_SHARE_NUMBER,
69
- KEY_SOURCE_LIST,
70
- KEY_STAGE,
71
- KEY_TARGET_RANGE,
72
- KEY_THRESHOLD,
73
55
  RECORD_KEY_CONFIGS,
74
56
  RECORD_KEY_STATE,
75
- STAGE_COLLECT_MASKED_INPUT,
76
- STAGE_SETUP,
77
- STAGE_SHARE_KEYS,
78
- STAGE_UNMASK,
79
- STAGES,
57
+ Key,
58
+ Stage,
80
59
  )
81
60
  from flwr.common.secure_aggregation.secaggplus_utils import (
82
61
  pseudo_rand_gen,
@@ -91,7 +70,7 @@ from flwr.common.typing import ConfigsRecordValues, FitRes
91
70
  class SecAggPlusState:
92
71
  """State of the SecAgg+ protocol."""
93
72
 
94
- current_stage: str = STAGE_UNMASK
73
+ current_stage: str = Stage.UNMASK
95
74
 
96
75
  sid: int = 0
97
76
  sample_num: int = 0
@@ -187,20 +166,20 @@ def secaggplus_mod(
187
166
  check_stage(state.current_stage, configs)
188
167
 
189
168
  # Update the current stage
190
- state.current_stage = cast(str, configs.pop(KEY_STAGE))
169
+ state.current_stage = cast(str, configs.pop(Key.STAGE))
191
170
 
192
171
  # Check the validity of the configs based on the current stage
193
172
  check_configs(state.current_stage, configs)
194
173
 
195
174
  # Execute
196
- if state.current_stage == STAGE_SETUP:
175
+ if state.current_stage == Stage.SETUP:
197
176
  res = _setup(state, configs)
198
- elif state.current_stage == STAGE_SHARE_KEYS:
177
+ elif state.current_stage == Stage.SHARE_KEYS:
199
178
  res = _share_keys(state, configs)
200
- elif state.current_stage == STAGE_COLLECT_MASKED_INPUT:
179
+ elif state.current_stage == Stage.COLLECT_MASKED_INPUT:
201
180
  fit = _get_fit_fn(msg, ctxt, call_next)
202
181
  res = _collect_masked_input(state, configs, fit)
203
- elif state.current_stage == STAGE_UNMASK:
182
+ elif state.current_stage == Stage.UNMASK:
204
183
  res = _unmask(state, configs)
205
184
  else:
206
185
  raise ValueError(f"Unknown secagg stage: {state.current_stage}")
@@ -215,28 +194,29 @@ def secaggplus_mod(
215
194
 
216
195
  def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
217
196
  """Check the validity of the next stage."""
218
- # Check the existence of KEY_STAGE
219
- if KEY_STAGE not in configs:
197
+ # Check the existence of Config.STAGE
198
+ if Key.STAGE not in configs:
220
199
  raise KeyError(
221
- f"The required key '{KEY_STAGE}' is missing from the input `named_values`."
200
+ f"The required key '{Key.STAGE}' is missing from the input `named_values`."
222
201
  )
223
202
 
224
- # Check the value type of the KEY_STAGE
225
- next_stage = configs[KEY_STAGE]
203
+ # Check the value type of the Config.STAGE
204
+ next_stage = configs[Key.STAGE]
226
205
  if not isinstance(next_stage, str):
227
206
  raise TypeError(
228
- f"The value for the key '{KEY_STAGE}' must be of type {str}, "
207
+ f"The value for the key '{Key.STAGE}' must be of type {str}, "
229
208
  f"but got {type(next_stage)} instead."
230
209
  )
231
210
 
232
211
  # Check the validity of the next stage
233
- if next_stage == STAGE_SETUP:
234
- if current_stage != STAGE_UNMASK:
212
+ if next_stage == Stage.SETUP:
213
+ if current_stage != Stage.UNMASK:
235
214
  log(WARNING, "Restart from the setup stage")
236
215
  # If stage is not "setup",
237
216
  # the stage from `named_values` should be the expected next stage
238
217
  else:
239
- expected_next_stage = STAGES[(STAGES.index(current_stage) + 1) % len(STAGES)]
218
+ stages = Stage.all()
219
+ expected_next_stage = stages[(stages.index(current_stage) + 1) % len(stages)]
240
220
  if next_stage != expected_next_stage:
241
221
  raise ValueError(
242
222
  "Abort secure aggregation: "
@@ -248,20 +228,20 @@ def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
248
228
  def check_configs(stage: str, configs: ConfigsRecord) -> None:
249
229
  """Check the validity of the configs."""
250
230
  # Check `named_values` for the setup stage
251
- if stage == STAGE_SETUP:
231
+ if stage == Stage.SETUP:
252
232
  key_type_pairs = [
253
- (KEY_SAMPLE_NUMBER, int),
254
- (KEY_SECURE_ID, int),
255
- (KEY_SHARE_NUMBER, int),
256
- (KEY_THRESHOLD, int),
257
- (KEY_CLIPPING_RANGE, float),
258
- (KEY_TARGET_RANGE, int),
259
- (KEY_MOD_RANGE, int),
233
+ (Key.SAMPLE_NUMBER, int),
234
+ (Key.SECURE_ID, int),
235
+ (Key.SHARE_NUMBER, int),
236
+ (Key.THRESHOLD, int),
237
+ (Key.CLIPPING_RANGE, float),
238
+ (Key.TARGET_RANGE, int),
239
+ (Key.MOD_RANGE, int),
260
240
  ]
261
241
  for key, expected_type in key_type_pairs:
262
242
  if key not in configs:
263
243
  raise KeyError(
264
- f"Stage {STAGE_SETUP}: the required key '{key}' is "
244
+ f"Stage {Stage.SETUP}: the required key '{key}' is "
265
245
  "missing from the input `named_values`."
266
246
  )
267
247
  # Bool is a subclass of int in Python,
@@ -269,11 +249,11 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
269
249
  # pylint: disable-next=unidiomatic-typecheck
270
250
  if type(configs[key]) is not expected_type:
271
251
  raise TypeError(
272
- f"Stage {STAGE_SETUP}: The value for the key '{key}' "
252
+ f"Stage {Stage.SETUP}: The value for the key '{key}' "
273
253
  f"must be of type {expected_type}, "
274
254
  f"but got {type(configs[key])} instead."
275
255
  )
276
- elif stage == STAGE_SHARE_KEYS:
256
+ elif stage == Stage.SHARE_KEYS:
277
257
  for key, value in configs.items():
278
258
  if (
279
259
  not isinstance(value, list)
@@ -282,18 +262,18 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
282
262
  or not isinstance(value[1], bytes)
283
263
  ):
284
264
  raise TypeError(
285
- f"Stage {STAGE_SHARE_KEYS}: "
265
+ f"Stage {Stage.SHARE_KEYS}: "
286
266
  f"the value for the key '{key}' must be a list of two bytes."
287
267
  )
288
- elif stage == STAGE_COLLECT_MASKED_INPUT:
268
+ elif stage == Stage.COLLECT_MASKED_INPUT:
289
269
  key_type_pairs = [
290
- (KEY_CIPHERTEXT_LIST, bytes),
291
- (KEY_SOURCE_LIST, int),
270
+ (Key.CIPHERTEXT_LIST, bytes),
271
+ (Key.SOURCE_LIST, int),
292
272
  ]
293
273
  for key, expected_type in key_type_pairs:
294
274
  if key not in configs:
295
275
  raise KeyError(
296
- f"Stage {STAGE_COLLECT_MASKED_INPUT}: "
276
+ f"Stage {Stage.COLLECT_MASKED_INPUT}: "
297
277
  f"the required key '{key}' is "
298
278
  "missing from the input `named_values`."
299
279
  )
@@ -304,19 +284,19 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
304
284
  if type(elm) is not expected_type
305
285
  ):
306
286
  raise TypeError(
307
- f"Stage {STAGE_COLLECT_MASKED_INPUT}: "
287
+ f"Stage {Stage.COLLECT_MASKED_INPUT}: "
308
288
  f"the value for the key '{key}' "
309
289
  f"must be of type List[{expected_type.__name__}]"
310
290
  )
311
- elif stage == STAGE_UNMASK:
291
+ elif stage == Stage.UNMASK:
312
292
  key_type_pairs = [
313
- (KEY_ACTIVE_SECURE_ID_LIST, int),
314
- (KEY_DEAD_SECURE_ID_LIST, int),
293
+ (Key.ACTIVE_SECURE_ID_LIST, int),
294
+ (Key.DEAD_SECURE_ID_LIST, int),
315
295
  ]
316
296
  for key, expected_type in key_type_pairs:
317
297
  if key not in configs:
318
298
  raise KeyError(
319
- f"Stage {STAGE_UNMASK}: "
299
+ f"Stage {Stage.UNMASK}: "
320
300
  f"the required key '{key}' is "
321
301
  "missing from the input `named_values`."
322
302
  )
@@ -327,7 +307,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
327
307
  if type(elm) is not expected_type
328
308
  ):
329
309
  raise TypeError(
330
- f"Stage {STAGE_UNMASK}: "
310
+ f"Stage {Stage.UNMASK}: "
331
311
  f"the value for the key '{key}' "
332
312
  f"must be of type List[{expected_type.__name__}]"
333
313
  )
@@ -340,15 +320,15 @@ def _setup(
340
320
  ) -> Dict[str, ConfigsRecordValues]:
341
321
  # Assigning parameter values to object fields
342
322
  sec_agg_param_dict = configs
343
- state.sample_num = cast(int, sec_agg_param_dict[KEY_SAMPLE_NUMBER])
344
- state.sid = cast(int, sec_agg_param_dict[KEY_SECURE_ID])
323
+ state.sample_num = cast(int, sec_agg_param_dict[Key.SAMPLE_NUMBER])
324
+ state.sid = cast(int, sec_agg_param_dict[Key.SECURE_ID])
345
325
  log(INFO, "Client %d: starting stage 0...", state.sid)
346
326
 
347
- state.share_num = cast(int, sec_agg_param_dict[KEY_SHARE_NUMBER])
348
- state.threshold = cast(int, sec_agg_param_dict[KEY_THRESHOLD])
349
- state.clipping_range = cast(float, sec_agg_param_dict[KEY_CLIPPING_RANGE])
350
- state.target_range = cast(int, sec_agg_param_dict[KEY_TARGET_RANGE])
351
- state.mod_range = cast(int, sec_agg_param_dict[KEY_MOD_RANGE])
327
+ state.share_num = cast(int, sec_agg_param_dict[Key.SHARE_NUMBER])
328
+ state.threshold = cast(int, sec_agg_param_dict[Key.THRESHOLD])
329
+ state.clipping_range = cast(float, sec_agg_param_dict[Key.CLIPPING_RANGE])
330
+ state.target_range = cast(int, sec_agg_param_dict[Key.TARGET_RANGE])
331
+ state.mod_range = cast(int, sec_agg_param_dict[Key.MOD_RANGE])
352
332
 
353
333
  # Dictionaries containing client secure IDs as keys
354
334
  # and their respective secret shares as values.
@@ -367,7 +347,7 @@ def _setup(
367
347
  state.sk1, state.pk1 = private_key_to_bytes(sk1), public_key_to_bytes(pk1)
368
348
  state.sk2, state.pk2 = private_key_to_bytes(sk2), public_key_to_bytes(pk2)
369
349
  log(INFO, "Client %d: stage 0 completes. uploading public keys...", state.sid)
370
- return {KEY_PUBLIC_KEY_1: state.pk1, KEY_PUBLIC_KEY_2: state.pk2}
350
+ return {Key.PUBLIC_KEY_1: state.pk1, Key.PUBLIC_KEY_2: state.pk2}
371
351
 
372
352
 
373
353
  # pylint: disable-next=too-many-locals
@@ -429,7 +409,7 @@ def _share_keys(
429
409
  ciphertexts.append(ciphertext)
430
410
 
431
411
  log(INFO, "Client %d: stage 1 completes. uploading key shares...", state.sid)
432
- return {KEY_DESTINATION_LIST: dsts, KEY_CIPHERTEXT_LIST: ciphertexts}
412
+ return {Key.DESTINATION_LIST: dsts, Key.CIPHERTEXT_LIST: ciphertexts}
433
413
 
434
414
 
435
415
  # pylint: disable-next=too-many-locals
@@ -440,8 +420,8 @@ def _collect_masked_input(
440
420
  ) -> Dict[str, ConfigsRecordValues]:
441
421
  log(INFO, "Client %d: starting stage 2...", state.sid)
442
422
  available_clients: List[int] = []
443
- ciphertexts = cast(List[bytes], configs[KEY_CIPHERTEXT_LIST])
444
- srcs = cast(List[int], configs[KEY_SOURCE_LIST])
423
+ ciphertexts = cast(List[bytes], configs[Key.CIPHERTEXT_LIST])
424
+ srcs = cast(List[int], configs[Key.SOURCE_LIST])
445
425
  if len(ciphertexts) + 1 < state.threshold:
446
426
  raise ValueError("Not enough available neighbour clients.")
447
427
 
@@ -505,7 +485,7 @@ def _collect_masked_input(
505
485
  quantized_parameters = parameters_mod(quantized_parameters, state.mod_range)
506
486
  log(INFO, "Client %d: stage 2 completes. uploading masked parameters...", state.sid)
507
487
  return {
508
- KEY_MASKED_PARAMETERS: [ndarray_to_bytes(arr) for arr in quantized_parameters]
488
+ Key.MASKED_PARAMETERS: [ndarray_to_bytes(arr) for arr in quantized_parameters]
509
489
  }
510
490
 
511
491
 
@@ -514,8 +494,8 @@ def _unmask(
514
494
  ) -> Dict[str, ConfigsRecordValues]:
515
495
  log(INFO, "Client %d: starting stage 3...", state.sid)
516
496
 
517
- active_sids = cast(List[int], configs[KEY_ACTIVE_SECURE_ID_LIST])
518
- dead_sids = cast(List[int], configs[KEY_DEAD_SECURE_ID_LIST])
497
+ active_sids = cast(List[int], configs[Key.ACTIVE_SECURE_ID_LIST])
498
+ dead_sids = cast(List[int], configs[Key.DEAD_SECURE_ID_LIST])
519
499
  # Send private mask seed share for every avaliable client (including itclient)
520
500
  # Send first private key share for building pairwise mask for every dropped client
521
501
  if len(active_sids) < state.threshold:
@@ -528,4 +508,4 @@ def _unmask(
528
508
  shares += [state.sk1_share_dict[sid] for sid in dead_sids]
529
509
 
530
510
  log(INFO, "Client %d: stage 3 completes. uploading key shares...", state.sid)
531
- return {KEY_SECURE_ID_LIST: sids, KEY_SHARE_LIST: shares}
511
+ return {Key.SECURE_ID_LIST: sids, Key.SHARE_LIST: shares}
@@ -15,6 +15,9 @@
15
15
  """Utility functions for differential privacy."""
16
16
 
17
17
 
18
+ from logging import WARNING
19
+ from typing import Optional, Tuple
20
+
18
21
  import numpy as np
19
22
 
20
23
  from flwr.common import (
@@ -23,6 +26,7 @@ from flwr.common import (
23
26
  ndarrays_to_parameters,
24
27
  parameters_to_ndarrays,
25
28
  )
29
+ from flwr.common.logger import log
26
30
 
27
31
 
28
32
  def get_norm(input_arrays: NDArrays) -> float:
@@ -72,6 +76,36 @@ def compute_clip_model_update(
72
76
  param1[i] = param2[i] + model_update[i]
73
77
 
74
78
 
def adaptive_clip_inputs_inplace(input_arrays: NDArrays, clipping_norm: float) -> bool:
    """Clip model update based on the clipping norm in-place.

    FlatClip method of the paper: https://arxiv.org/abs/1710.06963

    All arrays are multiplied in place by a single scaling factor,
    ``min(1, clipping_norm / get_norm(input_arrays))``.

    Parameters
    ----------
    input_arrays : NDArrays
        Arrays to be scaled in place.
    clipping_norm : float
        Upper bound for the norm returned by ``get_norm``.

    Returns
    -------
    bool
        True if the scaling factor is < 1 (the update was clipped); this is
        the value used as ``norm_bit``. False otherwise.
    """
    input_norm = get_norm(input_arrays)
    # A zero-norm update needs no scaling. Returning early avoids a division
    # by zero and yields the same result as a scaling factor of 1.
    if input_norm == 0:
        return False
    scaling_factor = min(1, clipping_norm / input_norm)
    for array in input_arrays:
        array *= scaling_factor
    return scaling_factor < 1
90
+
91
+
def compute_adaptive_clip_model_update(
    param1: NDArrays, param2: NDArrays, clipping_norm: float
) -> bool:
    """Clip the update ``param1 - param2`` and write the result into ``param1``.

    The element-wise difference between the two parameter lists is clipped by
    ``adaptive_clip_inputs_inplace``. Each entry of ``param1`` is then replaced
    by the corresponding ``param2`` entry plus the clipped difference.

    Returns
    -------
    bool
        The norm bit from ``adaptive_clip_inputs_inplace``, i.e. whether the
        update was scaled down.
    """
    deltas = [np.subtract(new, old) for new, old in zip(param1, param2)]
    was_clipped = adaptive_clip_inputs_inplace(deltas, clipping_norm)

    # Re-apply the (possibly scaled) deltas on top of the reference params.
    for idx, base in enumerate(param2):
        param1[idx] = base + deltas[idx]

    return was_clipped
107
+
108
+
75
109
  def add_gaussian_noise_to_params(
76
110
  model_params: Parameters,
77
111
  noise_multiplier: float,
@@ -85,3 +119,46 @@ def add_gaussian_noise_to_params(
85
119
  compute_stdv(noise_multiplier, clipping_norm, num_sampled_clients),
86
120
  )
87
121
  return ndarrays_to_parameters(model_params_ndarrays)
122
+
123
+
def compute_adaptive_noise_params(
    noise_multiplier: float,
    num_sampled_clients: float,
    clipped_count_stddev: Optional[float],
) -> Tuple[float, float]:
    """Compute noising parameters for the adaptive clipping.

    Paper: https://arxiv.org/abs/1905.03871

    Parameters
    ----------
    noise_multiplier : float
        Desired effective noise multiplier. Values ``<= 0`` disable noising.
    num_sampled_clients : float
        Number of sampled clients. When noising is enabled and
        ``clipped_count_stddev`` is None, the stddev defaults to
        ``num_sampled_clients / 20``.
    clipped_count_stddev : Optional[float]
        Stddev for the clipped-count noise. If None, it defaults as described
        above, or to ``0.0`` when noising is disabled.

    Returns
    -------
    Tuple[float, float]
        ``(clipped_count_stddev, noise_multiplier_value)``. The second value is
        ``(noise_multiplier**-2 - (2 * clipped_count_stddev)**-2) ** -0.5``:
        the noise multiplier needed so that, together with the clipped-count
        noise, the desired effective ``noise_multiplier`` is achieved. It is
        ``0.0`` when noising is disabled.

    Raises
    ------
    ValueError
        If ``noise_multiplier >= 2 * clipped_count_stddev``. The term inside
        the power is then non-positive, so no valid value exists.
    """
    if noise_multiplier > 0:
        stddev_is_default = clipped_count_stddev is None
        if clipped_count_stddev is None:
            clipped_count_stddev = num_sampled_clients / 20
        if noise_multiplier >= 2 * clipped_count_stddev:
            # Report the value actually used, and only mention the default
            # when the caller did not supply one.
            if stddev_is_default:
                raise ValueError(
                    "If not specified, `clipped_count_stddev` is set to "
                    "`num_sampled_clients`/20 by default. This value "
                    f"({clipped_count_stddev}) is too low to achieve the "
                    f"desired effective `noise_multiplier` ({noise_multiplier}). "
                    "Consider increasing `clipped_count_stddev` or decreasing "
                    "`noise_multiplier`."
                )
            raise ValueError(
                f"The provided `clipped_count_stddev` ({clipped_count_stddev}) "
                "is too low to achieve the desired effective "
                f"`noise_multiplier` ({noise_multiplier}); `noise_multiplier` "
                "must be smaller than 2 * `clipped_count_stddev`. "
                "Consider increasing `clipped_count_stddev` or decreasing "
                "`noise_multiplier`."
            )
        noise_multiplier_value = (
            noise_multiplier ** (-2) - (2 * clipped_count_stddev) ** (-2)
        ) ** -0.5

        # Ratio between the required and the requested noise multiplier.
        adding_noise = noise_multiplier_value / noise_multiplier
        if adding_noise >= 2:
            log(
                WARNING,
                "A significant amount of noise (%s) has to be "
                "added. Consider increasing `clipped_count_stddev` or "
                "`num_sampled_clients`.",
                adding_noise,
            )

    else:
        if clipped_count_stddev is None:
            clipped_count_stddev = 0.0
        noise_multiplier_value = 0.0

    return clipped_count_stddev, noise_multiplier_value
@@ -16,6 +16,7 @@
16
16
 
17
17
 
# Fit-config key under which the server-side wrapper sends the clipping norm
# to the client mods.
KEY_CLIPPING_NORM = "clipping_norm"
# Fit-metrics key under which a client reports its norm bit, i.e. whether its
# model update had to be scaled down by the clipping norm.
KEY_NORM_BIT = "norm_bit"
19
20
  CLIENTS_DISCREPANCY_WARNING = (
20
21
  "The number of clients returning parameters (%s)"
21
22
  " differs from the number of sampled clients (%s)."
@@ -14,33 +14,55 @@
14
14
  # ==============================================================================
15
15
  """Constants for the SecAgg/SecAgg+ protocol."""
16
16
 
17
+
18
+ from __future__ import annotations
19
+
# Record names used by the SecAgg+ client mod. Presumably these name the
# record holding the client's protocol state and the record carrying the
# per-stage configs (see `Key` below) -- confirm against secaggplus_mod.
RECORD_KEY_STATE = "secaggplus_state"
RECORD_KEY_CONFIGS = "secaggplus_configs"
19
22
 
20
- # Names of stages
21
- STAGE_SETUP = "setup"
22
- STAGE_SHARE_KEYS = "share_keys"
23
- STAGE_COLLECT_MASKED_INPUT = "collect_masked_input"
24
- STAGE_UNMASK = "unmask"
25
- STAGES = (STAGE_SETUP, STAGE_SHARE_KEYS, STAGE_COLLECT_MASKED_INPUT, STAGE_UNMASK)
26
-
27
- # All valid keys in received/replied `named_values` dictionaries
28
- KEY_STAGE = "stage"
29
- KEY_SAMPLE_NUMBER = "sample_num"
30
- KEY_SECURE_ID = "secure_id"
31
- KEY_SHARE_NUMBER = "share_num"
32
- KEY_THRESHOLD = "threshold"
33
- KEY_CLIPPING_RANGE = "clipping_range"
34
- KEY_TARGET_RANGE = "target_range"
35
- KEY_MOD_RANGE = "mod_range"
36
- KEY_PUBLIC_KEY_1 = "pk1"
37
- KEY_PUBLIC_KEY_2 = "pk2"
38
- KEY_DESTINATION_LIST = "dsts"
39
- KEY_CIPHERTEXT_LIST = "ctxts"
40
- KEY_SOURCE_LIST = "srcs"
41
- KEY_PARAMETERS = "params"
42
- KEY_MASKED_PARAMETERS = "masked_params"
43
- KEY_ACTIVE_SECURE_ID_LIST = "active_sids"
44
- KEY_DEAD_SECURE_ID_LIST = "dead_sids"
45
- KEY_SECURE_ID_LIST = "sids"
46
- KEY_SHARE_LIST = "shares"
23
+
class Stage:
    """Names of the SecAgg+ protocol stages.

    Pure namespace of string constants; it cannot be instantiated.
    """

    SETUP = "setup"
    SHARE_KEYS = "share_keys"
    COLLECT_MASKED_INPUT = "collect_masked_input"
    UNMASK = "unmask"
    # Protocol order. The stage after the last one wraps around to the first.
    _stages = (SETUP, SHARE_KEYS, COLLECT_MASKED_INPUT, UNMASK)

    @classmethod
    def all(cls) -> tuple[str, str, str, str]:
        """Return every stage name, in protocol order."""
        ordered = cls._stages
        return ordered

    def __new__(cls) -> Stage:
        """Refuse to create instances of this constants namespace."""
        class_name = cls.__name__
        raise TypeError(f"{class_name} cannot be instantiated.")
41
+
42
+
class Key:
    """Keys for the configs in the ConfigsRecord.

    Pure namespace of string constants; it cannot be instantiated.
    """

    # Key carrying the name of the next protocol stage.
    STAGE = "stage"

    # Setup-stage configuration values.
    SAMPLE_NUMBER = "sample_num"
    SECURE_ID = "secure_id"
    SHARE_NUMBER = "share_num"
    THRESHOLD = "threshold"
    CLIPPING_RANGE = "clipping_range"
    TARGET_RANGE = "target_range"
    MOD_RANGE = "mod_range"

    # Public keys returned by the setup stage.
    PUBLIC_KEY_1 = "pk1"
    PUBLIC_KEY_2 = "pk2"

    # Share-keys stage output; ciphertexts and their sources are inputs to
    # the collect-masked-input stage.
    DESTINATION_LIST = "dsts"
    CIPHERTEXT_LIST = "ctxts"
    SOURCE_LIST = "srcs"

    # NOTE(review): not referenced by the client mod code visible here.
    PARAMETERS = "params"

    # Output of the collect-masked-input stage.
    MASKED_PARAMETERS = "masked_params"

    # Unmask-stage inputs.
    ACTIVE_SECURE_ID_LIST = "active_sids"
    DEAD_SECURE_ID_LIST = "dead_sids"

    # Unmask-stage outputs.
    SECURE_ID_LIST = "sids"
    SHARE_LIST = "shares"

    def __new__(cls) -> Key:
        """Refuse to create instances of this constants namespace."""
        class_name = cls.__name__
        raise TypeError(f"{class_name} cannot be instantiated.")
@@ -0,0 +1,26 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
3
+ # source: flwr/proto/error.proto
4
+ # Protobuf Python Version: 4.25.0
5
+ """Generated protocol buffer code."""
6
+ from google.protobuf import descriptor as _descriptor
7
+ from google.protobuf import descriptor_pool as _descriptor_pool
8
+ from google.protobuf import symbol_database as _symbol_database
9
+ from google.protobuf.internal import builder as _builder
10
+ # @@protoc_insertion_point(imports)
11
+
12
+ _sym_db = _symbol_database.Default()
13
+
14
+
15
+
16
+
17
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/error.proto\x12\nflwr.proto\"%\n\x05\x45rror\x12\x0c\n\x04\x63ode\x18\x01 \x01(\x12\x12\x0e\n\x06reason\x18\x02 \x01(\tb\x06proto3')
18
+
19
+ _globals = globals()
20
+ _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
21
+ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.error_pb2', _globals)
22
+ if _descriptor._USE_C_DESCRIPTORS == False:
23
+ DESCRIPTOR._options = None
24
+ _globals['_ERROR']._serialized_start=38
25
+ _globals['_ERROR']._serialized_end=75
26
+ # @@protoc_insertion_point(module_scope)
@@ -0,0 +1,25 @@
1
+ """
2
+ @generated by mypy-protobuf. Do not edit manually!
3
+ isort:skip_file
4
+ """
5
+ import builtins
6
+ import google.protobuf.descriptor
7
+ import google.protobuf.message
8
+ import typing
9
+ import typing_extensions
10
+
11
+ DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
12
+
13
+ class Error(google.protobuf.message.Message):
14
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
15
+ CODE_FIELD_NUMBER: builtins.int
16
+ REASON_FIELD_NUMBER: builtins.int
17
+ code: builtins.int
18
+ reason: typing.Text
19
+ def __init__(self,
20
+ *,
21
+ code: builtins.int = ...,
22
+ reason: typing.Text = ...,
23
+ ) -> None: ...
24
+ def ClearField(self, field_name: typing_extensions.Literal["code",b"code","reason",b"reason"]) -> None: ...
25
+ global___Error = Error
@@ -0,0 +1,4 @@
1
+ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
2
+ """Client and server classes corresponding to protobuf-defined services."""
3
+ import grpc
4
+
@@ -0,0 +1,4 @@
1
+ """
2
+ @generated by mypy-protobuf. Do not edit manually!
3
+ isort:skip_file
4
+ """
flwr/proto/task_pb2.py CHANGED
@@ -15,19 +15,20 @@ _sym_db = _symbol_database.Default()
15
15
  from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2
16
16
  from flwr.proto import recordset_pb2 as flwr_dot_proto_dot_recordset__pb2
17
17
  from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2
18
+ from flwr.proto import error_pb2 as flwr_dot_proto_dot_error__pb2
18
19
 
19
20
 
20
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xd4\x01\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\t\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x0b\n\x03ttl\x18\x05 \x01(\t\x12\x10\n\x08\x61ncestry\x18\x06 \x03(\t\x12\x11\n\ttask_type\x18\x07 \x01(\t\x12(\n\trecordset\x18\x08 \x01(\x0b\x32\x15.flwr.proto.RecordSet\"\\\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\\\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Taskb\x06proto3')
21
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x16\x66lwr/proto/error.proto\"\xf6\x01\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\t\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x0b\n\x03ttl\x18\x05 \x01(\t\x12\x10\n\x08\x61ncestry\x18\x06 \x03(\t\x12\x11\n\ttask_type\x18\x07 \x01(\t\x12(\n\trecordset\x18\x08 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12 \n\x05\x65rror\x18\t \x01(\x0b\x32\x11.flwr.proto.Error\"\\\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\\\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Taskb\x06proto3')
21
22
 
22
23
  _globals = globals()
23
24
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
24
25
  _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.task_pb2', _globals)
25
26
  if _descriptor._USE_C_DESCRIPTORS == False:
26
27
  DESCRIPTOR._options = None
27
- _globals['_TASK']._serialized_start=117
28
- _globals['_TASK']._serialized_end=329
29
- _globals['_TASKINS']._serialized_start=331
30
- _globals['_TASKINS']._serialized_end=423
31
- _globals['_TASKRES']._serialized_start=425
32
- _globals['_TASKRES']._serialized_end=517
28
+ _globals['_TASK']._serialized_start=141
29
+ _globals['_TASK']._serialized_end=387
30
+ _globals['_TASKINS']._serialized_start=389
31
+ _globals['_TASKINS']._serialized_end=481
32
+ _globals['_TASKRES']._serialized_start=483
33
+ _globals['_TASKRES']._serialized_end=575
33
34
  # @@protoc_insertion_point(module_scope)