flwr-nightly 1.8.0.dev20240227__py3-none-any.whl → 1.8.0.dev20240229__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (42) hide show
  1. flwr/client/mod/__init__.py +3 -2
  2. flwr/client/mod/centraldp_mods.py +63 -2
  3. flwr/client/mod/secure_aggregation/secaggplus_mod.py +55 -75
  4. flwr/common/differential_privacy.py +77 -0
  5. flwr/common/differential_privacy_constants.py +1 -0
  6. flwr/common/secure_aggregation/secaggplus_constants.py +49 -27
  7. flwr/proto/error_pb2.py +26 -0
  8. flwr/proto/error_pb2.pyi +25 -0
  9. flwr/proto/error_pb2_grpc.py +4 -0
  10. flwr/proto/error_pb2_grpc.pyi +4 -0
  11. flwr/proto/task_pb2.py +8 -7
  12. flwr/proto/task_pb2.pyi +7 -2
  13. flwr/server/__init__.py +4 -0
  14. flwr/server/app.py +8 -31
  15. flwr/server/client_proxy.py +5 -0
  16. flwr/server/compat/__init__.py +2 -0
  17. flwr/server/compat/app.py +7 -88
  18. flwr/server/compat/app_utils.py +102 -0
  19. flwr/server/compat/driver_client_proxy.py +22 -10
  20. flwr/server/compat/legacy_context.py +55 -0
  21. flwr/server/run_serverapp.py +1 -1
  22. flwr/server/server.py +18 -8
  23. flwr/server/strategy/__init__.py +24 -14
  24. flwr/server/strategy/dp_adaptive_clipping.py +449 -0
  25. flwr/server/strategy/dp_fixed_clipping.py +5 -7
  26. flwr/server/superlink/driver/driver_grpc.py +54 -0
  27. flwr/server/superlink/driver/driver_servicer.py +4 -4
  28. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +5 -0
  29. flwr/server/superlink/fleet/vce/__init__.py +1 -1
  30. flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -4
  31. flwr/server/superlink/fleet/vce/vce_api.py +236 -16
  32. flwr/server/typing.py +1 -0
  33. flwr/server/workflow/__init__.py +22 -0
  34. flwr/server/workflow/default_workflows.py +357 -0
  35. flwr/simulation/__init__.py +3 -0
  36. flwr/simulation/ray_transport/ray_client_proxy.py +28 -8
  37. flwr/simulation/run_simulation.py +177 -0
  38. {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/METADATA +4 -3
  39. {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/RECORD +42 -31
  40. {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/entry_points.txt +1 -0
  41. {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/LICENSE +0 -0
  42. {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/WHEEL +0 -0
@@ -15,12 +15,13 @@
15
15
  """Mods."""
16
16
 
17
17
 
18
- from .centraldp_mods import fixedclipping_mod
18
+ from .centraldp_mods import adaptiveclipping_mod, fixedclipping_mod
19
19
  from .secure_aggregation.secaggplus_mod import secaggplus_mod
20
20
  from .utils import make_ffn
21
21
 
22
22
  __all__ = [
23
+ "adaptiveclipping_mod",
24
+ "fixedclipping_mod",
23
25
  "make_ffn",
24
26
  "secaggplus_mod",
25
- "fixedclipping_mod",
26
27
  ]
@@ -20,8 +20,11 @@ from flwr.common import ndarrays_to_parameters, parameters_to_ndarrays
20
20
  from flwr.common import recordset_compat as compat
21
21
  from flwr.common.constant import MESSAGE_TYPE_FIT
22
22
  from flwr.common.context import Context
23
- from flwr.common.differential_privacy import compute_clip_model_update
24
- from flwr.common.differential_privacy_constants import KEY_CLIPPING_NORM
23
+ from flwr.common.differential_privacy import (
24
+ compute_adaptive_clip_model_update,
25
+ compute_clip_model_update,
26
+ )
27
+ from flwr.common.differential_privacy_constants import KEY_CLIPPING_NORM, KEY_NORM_BIT
25
28
  from flwr.common.message import Message
26
29
 
27
30
 
@@ -74,3 +77,61 @@ def fixedclipping_mod(
74
77
  fit_res.parameters = ndarrays_to_parameters(client_to_server_params)
75
78
  out_msg.content = compat.fitres_to_recordset(fit_res, keep_input=True)
76
79
  return out_msg
80
+
81
+
82
def adaptiveclipping_mod(
    msg: Message, ctxt: Context, call_next: ClientAppCallable
) -> Message:
    """Client-side adaptive clipping modifier.

    This mod needs to be used with the DifferentialPrivacyClientSideAdaptiveClipping
    server-side strategy wrapper.

    The wrapper sends the clipping_norm value to the client.

    This mod clips the client model updates before sending them to the server.

    It also sends KEY_NORM_BIT to the server for computing the new clipping value.

    It operates on messages with type MESSAGE_TYPE_FIT.

    Parameters
    ----------
    msg : Message
        Incoming message. Messages whose type is not MESSAGE_TYPE_FIT are
        forwarded to ``call_next`` untouched.
    ctxt : Context
        Context passed through to ``call_next``.
    call_next : ClientAppCallable
        Next mod / inner app in the chain.

    Returns
    -------
    Message
        The reply produced by ``call_next``. For fit messages, its parameters
        are replaced by the server parameters plus the clipped update, and
        ``KEY_NORM_BIT`` is added to its metrics.

    Raises
    ------
    KeyError
        If the fit config does not contain ``KEY_CLIPPING_NORM``.
    ValueError
        If the supplied clipping norm is not a float.

    Notes
    -----
    Consider the order of mods when using multiple.

    Typically, adaptiveclipping_mod should be the last to operate on params.
    """
    if msg.metadata.message_type != MESSAGE_TYPE_FIT:
        return call_next(msg, ctxt)

    fit_ins = compat.recordset_to_fitins(msg.content, keep_input=True)

    if KEY_CLIPPING_NORM not in fit_ins.config:
        raise KeyError(
            f"The {KEY_CLIPPING_NORM} value is not supplied by the "
            f"DifferentialPrivacyClientSideAdaptiveClipping wrapper at"
            f" the server side."
        )
    if not isinstance(fit_ins.config[KEY_CLIPPING_NORM], float):
        raise ValueError(f"{KEY_CLIPPING_NORM} should be a float value.")
    clipping_norm = float(fit_ins.config[KEY_CLIPPING_NORM])
    # Remember the parameters received from the server so the model update
    # (client params - server params) can be computed after training.
    server_to_client_params = parameters_to_ndarrays(fit_ins.parameters)

    # Call inner app
    out_msg = call_next(msg, ctxt)
    fit_res = compat.recordset_to_fitres(out_msg.content, keep_input=True)

    client_to_server_params = parameters_to_ndarrays(fit_res.parameters)

    # Clip the client update. client_to_server_params is overwritten in place
    # with server params + clipped update; the return value is the norm bit.
    norm_bit = compute_adaptive_clip_model_update(
        client_to_server_params,
        server_to_client_params,
        clipping_norm,
    )

    fit_res.parameters = ndarrays_to_parameters(client_to_server_params)

    # Report the norm bit so the server can compute the new clipping value.
    fit_res.metrics[KEY_NORM_BIT] = norm_bit
    out_msg.content = compat.fitres_to_recordset(fit_res, keep_input=True)
    return out_msg
@@ -52,31 +52,10 @@ from flwr.common.secure_aggregation.ndarrays_arithmetic import (
52
52
  )
53
53
  from flwr.common.secure_aggregation.quantization import quantize
54
54
  from flwr.common.secure_aggregation.secaggplus_constants import (
55
- KEY_ACTIVE_SECURE_ID_LIST,
56
- KEY_CIPHERTEXT_LIST,
57
- KEY_CLIPPING_RANGE,
58
- KEY_DEAD_SECURE_ID_LIST,
59
- KEY_DESTINATION_LIST,
60
- KEY_MASKED_PARAMETERS,
61
- KEY_MOD_RANGE,
62
- KEY_PUBLIC_KEY_1,
63
- KEY_PUBLIC_KEY_2,
64
- KEY_SAMPLE_NUMBER,
65
- KEY_SECURE_ID,
66
- KEY_SECURE_ID_LIST,
67
- KEY_SHARE_LIST,
68
- KEY_SHARE_NUMBER,
69
- KEY_SOURCE_LIST,
70
- KEY_STAGE,
71
- KEY_TARGET_RANGE,
72
- KEY_THRESHOLD,
73
55
  RECORD_KEY_CONFIGS,
74
56
  RECORD_KEY_STATE,
75
- STAGE_COLLECT_MASKED_INPUT,
76
- STAGE_SETUP,
77
- STAGE_SHARE_KEYS,
78
- STAGE_UNMASK,
79
- STAGES,
57
+ Key,
58
+ Stage,
80
59
  )
81
60
  from flwr.common.secure_aggregation.secaggplus_utils import (
82
61
  pseudo_rand_gen,
@@ -91,7 +70,7 @@ from flwr.common.typing import ConfigsRecordValues, FitRes
91
70
  class SecAggPlusState:
92
71
  """State of the SecAgg+ protocol."""
93
72
 
94
- current_stage: str = STAGE_UNMASK
73
+ current_stage: str = Stage.UNMASK
95
74
 
96
75
  sid: int = 0
97
76
  sample_num: int = 0
@@ -187,20 +166,20 @@ def secaggplus_mod(
187
166
  check_stage(state.current_stage, configs)
188
167
 
189
168
  # Update the current stage
190
- state.current_stage = cast(str, configs.pop(KEY_STAGE))
169
+ state.current_stage = cast(str, configs.pop(Key.STAGE))
191
170
 
192
171
  # Check the validity of the configs based on the current stage
193
172
  check_configs(state.current_stage, configs)
194
173
 
195
174
  # Execute
196
- if state.current_stage == STAGE_SETUP:
175
+ if state.current_stage == Stage.SETUP:
197
176
  res = _setup(state, configs)
198
- elif state.current_stage == STAGE_SHARE_KEYS:
177
+ elif state.current_stage == Stage.SHARE_KEYS:
199
178
  res = _share_keys(state, configs)
200
- elif state.current_stage == STAGE_COLLECT_MASKED_INPUT:
179
+ elif state.current_stage == Stage.COLLECT_MASKED_INPUT:
201
180
  fit = _get_fit_fn(msg, ctxt, call_next)
202
181
  res = _collect_masked_input(state, configs, fit)
203
- elif state.current_stage == STAGE_UNMASK:
182
+ elif state.current_stage == Stage.UNMASK:
204
183
  res = _unmask(state, configs)
205
184
  else:
206
185
  raise ValueError(f"Unknown secagg stage: {state.current_stage}")
@@ -215,28 +194,29 @@ def secaggplus_mod(
215
194
 
216
195
  def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
217
196
  """Check the validity of the next stage."""
218
- # Check the existence of KEY_STAGE
219
- if KEY_STAGE not in configs:
197
+ # Check the existence of Config.STAGE
198
+ if Key.STAGE not in configs:
220
199
  raise KeyError(
221
- f"The required key '{KEY_STAGE}' is missing from the input `named_values`."
200
+ f"The required key '{Key.STAGE}' is missing from the input `named_values`."
222
201
  )
223
202
 
224
- # Check the value type of the KEY_STAGE
225
- next_stage = configs[KEY_STAGE]
203
+ # Check the value type of the Config.STAGE
204
+ next_stage = configs[Key.STAGE]
226
205
  if not isinstance(next_stage, str):
227
206
  raise TypeError(
228
- f"The value for the key '{KEY_STAGE}' must be of type {str}, "
207
+ f"The value for the key '{Key.STAGE}' must be of type {str}, "
229
208
  f"but got {type(next_stage)} instead."
230
209
  )
231
210
 
232
211
  # Check the validity of the next stage
233
- if next_stage == STAGE_SETUP:
234
- if current_stage != STAGE_UNMASK:
212
+ if next_stage == Stage.SETUP:
213
+ if current_stage != Stage.UNMASK:
235
214
  log(WARNING, "Restart from the setup stage")
236
215
  # If stage is not "setup",
237
216
  # the stage from `named_values` should be the expected next stage
238
217
  else:
239
- expected_next_stage = STAGES[(STAGES.index(current_stage) + 1) % len(STAGES)]
218
+ stages = Stage.all()
219
+ expected_next_stage = stages[(stages.index(current_stage) + 1) % len(stages)]
240
220
  if next_stage != expected_next_stage:
241
221
  raise ValueError(
242
222
  "Abort secure aggregation: "
@@ -248,20 +228,20 @@ def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
248
228
  def check_configs(stage: str, configs: ConfigsRecord) -> None:
249
229
  """Check the validity of the configs."""
250
230
  # Check `named_values` for the setup stage
251
- if stage == STAGE_SETUP:
231
+ if stage == Stage.SETUP:
252
232
  key_type_pairs = [
253
- (KEY_SAMPLE_NUMBER, int),
254
- (KEY_SECURE_ID, int),
255
- (KEY_SHARE_NUMBER, int),
256
- (KEY_THRESHOLD, int),
257
- (KEY_CLIPPING_RANGE, float),
258
- (KEY_TARGET_RANGE, int),
259
- (KEY_MOD_RANGE, int),
233
+ (Key.SAMPLE_NUMBER, int),
234
+ (Key.SECURE_ID, int),
235
+ (Key.SHARE_NUMBER, int),
236
+ (Key.THRESHOLD, int),
237
+ (Key.CLIPPING_RANGE, float),
238
+ (Key.TARGET_RANGE, int),
239
+ (Key.MOD_RANGE, int),
260
240
  ]
261
241
  for key, expected_type in key_type_pairs:
262
242
  if key not in configs:
263
243
  raise KeyError(
264
- f"Stage {STAGE_SETUP}: the required key '{key}' is "
244
+ f"Stage {Stage.SETUP}: the required key '{key}' is "
265
245
  "missing from the input `named_values`."
266
246
  )
267
247
  # Bool is a subclass of int in Python,
@@ -269,11 +249,11 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
269
249
  # pylint: disable-next=unidiomatic-typecheck
270
250
  if type(configs[key]) is not expected_type:
271
251
  raise TypeError(
272
- f"Stage {STAGE_SETUP}: The value for the key '{key}' "
252
+ f"Stage {Stage.SETUP}: The value for the key '{key}' "
273
253
  f"must be of type {expected_type}, "
274
254
  f"but got {type(configs[key])} instead."
275
255
  )
276
- elif stage == STAGE_SHARE_KEYS:
256
+ elif stage == Stage.SHARE_KEYS:
277
257
  for key, value in configs.items():
278
258
  if (
279
259
  not isinstance(value, list)
@@ -282,18 +262,18 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
282
262
  or not isinstance(value[1], bytes)
283
263
  ):
284
264
  raise TypeError(
285
- f"Stage {STAGE_SHARE_KEYS}: "
265
+ f"Stage {Stage.SHARE_KEYS}: "
286
266
  f"the value for the key '{key}' must be a list of two bytes."
287
267
  )
288
- elif stage == STAGE_COLLECT_MASKED_INPUT:
268
+ elif stage == Stage.COLLECT_MASKED_INPUT:
289
269
  key_type_pairs = [
290
- (KEY_CIPHERTEXT_LIST, bytes),
291
- (KEY_SOURCE_LIST, int),
270
+ (Key.CIPHERTEXT_LIST, bytes),
271
+ (Key.SOURCE_LIST, int),
292
272
  ]
293
273
  for key, expected_type in key_type_pairs:
294
274
  if key not in configs:
295
275
  raise KeyError(
296
- f"Stage {STAGE_COLLECT_MASKED_INPUT}: "
276
+ f"Stage {Stage.COLLECT_MASKED_INPUT}: "
297
277
  f"the required key '{key}' is "
298
278
  "missing from the input `named_values`."
299
279
  )
@@ -304,19 +284,19 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
304
284
  if type(elm) is not expected_type
305
285
  ):
306
286
  raise TypeError(
307
- f"Stage {STAGE_COLLECT_MASKED_INPUT}: "
287
+ f"Stage {Stage.COLLECT_MASKED_INPUT}: "
308
288
  f"the value for the key '{key}' "
309
289
  f"must be of type List[{expected_type.__name__}]"
310
290
  )
311
- elif stage == STAGE_UNMASK:
291
+ elif stage == Stage.UNMASK:
312
292
  key_type_pairs = [
313
- (KEY_ACTIVE_SECURE_ID_LIST, int),
314
- (KEY_DEAD_SECURE_ID_LIST, int),
293
+ (Key.ACTIVE_SECURE_ID_LIST, int),
294
+ (Key.DEAD_SECURE_ID_LIST, int),
315
295
  ]
316
296
  for key, expected_type in key_type_pairs:
317
297
  if key not in configs:
318
298
  raise KeyError(
319
- f"Stage {STAGE_UNMASK}: "
299
+ f"Stage {Stage.UNMASK}: "
320
300
  f"the required key '{key}' is "
321
301
  "missing from the input `named_values`."
322
302
  )
@@ -327,7 +307,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
327
307
  if type(elm) is not expected_type
328
308
  ):
329
309
  raise TypeError(
330
- f"Stage {STAGE_UNMASK}: "
310
+ f"Stage {Stage.UNMASK}: "
331
311
  f"the value for the key '{key}' "
332
312
  f"must be of type List[{expected_type.__name__}]"
333
313
  )
@@ -340,15 +320,15 @@ def _setup(
340
320
  ) -> Dict[str, ConfigsRecordValues]:
341
321
  # Assigning parameter values to object fields
342
322
  sec_agg_param_dict = configs
343
- state.sample_num = cast(int, sec_agg_param_dict[KEY_SAMPLE_NUMBER])
344
- state.sid = cast(int, sec_agg_param_dict[KEY_SECURE_ID])
323
+ state.sample_num = cast(int, sec_agg_param_dict[Key.SAMPLE_NUMBER])
324
+ state.sid = cast(int, sec_agg_param_dict[Key.SECURE_ID])
345
325
  log(INFO, "Client %d: starting stage 0...", state.sid)
346
326
 
347
- state.share_num = cast(int, sec_agg_param_dict[KEY_SHARE_NUMBER])
348
- state.threshold = cast(int, sec_agg_param_dict[KEY_THRESHOLD])
349
- state.clipping_range = cast(float, sec_agg_param_dict[KEY_CLIPPING_RANGE])
350
- state.target_range = cast(int, sec_agg_param_dict[KEY_TARGET_RANGE])
351
- state.mod_range = cast(int, sec_agg_param_dict[KEY_MOD_RANGE])
327
+ state.share_num = cast(int, sec_agg_param_dict[Key.SHARE_NUMBER])
328
+ state.threshold = cast(int, sec_agg_param_dict[Key.THRESHOLD])
329
+ state.clipping_range = cast(float, sec_agg_param_dict[Key.CLIPPING_RANGE])
330
+ state.target_range = cast(int, sec_agg_param_dict[Key.TARGET_RANGE])
331
+ state.mod_range = cast(int, sec_agg_param_dict[Key.MOD_RANGE])
352
332
 
353
333
  # Dictionaries containing client secure IDs as keys
354
334
  # and their respective secret shares as values.
@@ -367,7 +347,7 @@ def _setup(
367
347
  state.sk1, state.pk1 = private_key_to_bytes(sk1), public_key_to_bytes(pk1)
368
348
  state.sk2, state.pk2 = private_key_to_bytes(sk2), public_key_to_bytes(pk2)
369
349
  log(INFO, "Client %d: stage 0 completes. uploading public keys...", state.sid)
370
- return {KEY_PUBLIC_KEY_1: state.pk1, KEY_PUBLIC_KEY_2: state.pk2}
350
+ return {Key.PUBLIC_KEY_1: state.pk1, Key.PUBLIC_KEY_2: state.pk2}
371
351
 
372
352
 
373
353
  # pylint: disable-next=too-many-locals
@@ -429,7 +409,7 @@ def _share_keys(
429
409
  ciphertexts.append(ciphertext)
430
410
 
431
411
  log(INFO, "Client %d: stage 1 completes. uploading key shares...", state.sid)
432
- return {KEY_DESTINATION_LIST: dsts, KEY_CIPHERTEXT_LIST: ciphertexts}
412
+ return {Key.DESTINATION_LIST: dsts, Key.CIPHERTEXT_LIST: ciphertexts}
433
413
 
434
414
 
435
415
  # pylint: disable-next=too-many-locals
@@ -440,8 +420,8 @@ def _collect_masked_input(
440
420
  ) -> Dict[str, ConfigsRecordValues]:
441
421
  log(INFO, "Client %d: starting stage 2...", state.sid)
442
422
  available_clients: List[int] = []
443
- ciphertexts = cast(List[bytes], configs[KEY_CIPHERTEXT_LIST])
444
- srcs = cast(List[int], configs[KEY_SOURCE_LIST])
423
+ ciphertexts = cast(List[bytes], configs[Key.CIPHERTEXT_LIST])
424
+ srcs = cast(List[int], configs[Key.SOURCE_LIST])
445
425
  if len(ciphertexts) + 1 < state.threshold:
446
426
  raise ValueError("Not enough available neighbour clients.")
447
427
 
@@ -505,7 +485,7 @@ def _collect_masked_input(
505
485
  quantized_parameters = parameters_mod(quantized_parameters, state.mod_range)
506
486
  log(INFO, "Client %d: stage 2 completes. uploading masked parameters...", state.sid)
507
487
  return {
508
- KEY_MASKED_PARAMETERS: [ndarray_to_bytes(arr) for arr in quantized_parameters]
488
+ Key.MASKED_PARAMETERS: [ndarray_to_bytes(arr) for arr in quantized_parameters]
509
489
  }
510
490
 
511
491
 
@@ -514,8 +494,8 @@ def _unmask(
514
494
  ) -> Dict[str, ConfigsRecordValues]:
515
495
  log(INFO, "Client %d: starting stage 3...", state.sid)
516
496
 
517
- active_sids = cast(List[int], configs[KEY_ACTIVE_SECURE_ID_LIST])
518
- dead_sids = cast(List[int], configs[KEY_DEAD_SECURE_ID_LIST])
497
+ active_sids = cast(List[int], configs[Key.ACTIVE_SECURE_ID_LIST])
498
+ dead_sids = cast(List[int], configs[Key.DEAD_SECURE_ID_LIST])
519
499
  # Send private mask seed share for every avaliable client (including itclient)
520
500
  # Send first private key share for building pairwise mask for every dropped client
521
501
  if len(active_sids) < state.threshold:
@@ -528,4 +508,4 @@ def _unmask(
528
508
  shares += [state.sk1_share_dict[sid] for sid in dead_sids]
529
509
 
530
510
  log(INFO, "Client %d: stage 3 completes. uploading key shares...", state.sid)
531
- return {KEY_SECURE_ID_LIST: sids, KEY_SHARE_LIST: shares}
511
+ return {Key.SECURE_ID_LIST: sids, Key.SHARE_LIST: shares}
@@ -15,6 +15,9 @@
15
15
  """Utility functions for differential privacy."""
16
16
 
17
17
 
18
+ from logging import WARNING
19
+ from typing import Optional, Tuple
20
+
18
21
  import numpy as np
19
22
 
20
23
  from flwr.common import (
@@ -23,6 +26,7 @@ from flwr.common import (
23
26
  ndarrays_to_parameters,
24
27
  parameters_to_ndarrays,
25
28
  )
29
+ from flwr.common.logger import log
26
30
 
27
31
 
28
32
  def get_norm(input_arrays: NDArrays) -> float:
@@ -72,6 +76,36 @@ def compute_clip_model_update(
72
76
  param1[i] = param2[i] + model_update[i]
73
77
 
74
78
 
79
def adaptive_clip_inputs_inplace(input_arrays: NDArrays, clipping_norm: float) -> bool:
    """Clip model update based on the clipping norm in-place.

    FlatClip method of the paper: https://arxiv.org/abs/1710.06963

    Every array in ``input_arrays`` is multiplied in place by
    ``min(1, clipping_norm / norm)``, where ``norm`` is
    ``get_norm(input_arrays)``.

    Parameters
    ----------
    input_arrays : NDArrays
        Arrays to clip; they are modified in place.
    clipping_norm : float
        Bound that the input norm is scaled down to when it exceeds it.

    Returns
    -------
    bool
        True if the input was scaled down (scaling factor < 1). This is the
        norm bit reported to the server. False otherwise, including when the
        input norm is zero.
    """
    input_norm = get_norm(input_arrays)
    # A zero-norm input is already within any non-negative bound; returning
    # early also avoids a division by zero below.
    if input_norm == 0:
        return False
    scaling_factor = min(1, clipping_norm / input_norm)
    for array in input_arrays:
        array *= scaling_factor
    return scaling_factor < 1
90
+
91
+
92
def compute_adaptive_clip_model_update(
    param1: NDArrays, param2: NDArrays, clipping_norm: float
) -> bool:
    """Clip the update ``param1 - param2`` and write ``param2 + update`` into param1.

    The per-array difference between ``param1`` and ``param2`` is clipped
    with ``adaptive_clip_inputs_inplace``. Afterwards each ``param1[i]`` is
    replaced by ``param2[i]`` plus the clipped difference.

    Returns
    -------
    bool
        The norm bit: True if the update had to be scaled down.
    """
    deltas = [np.subtract(new, old) for new, old in zip(param1, param2)]
    was_clipped = adaptive_clip_inputs_inplace(deltas, clipping_norm)

    # Overwrite param1's entries so the caller's list holds the clipped model.
    for idx, base in enumerate(param2):
        param1[idx] = base + deltas[idx]

    return was_clipped
107
+
108
+
75
109
  def add_gaussian_noise_to_params(
76
110
  model_params: Parameters,
77
111
  noise_multiplier: float,
@@ -85,3 +119,46 @@ def add_gaussian_noise_to_params(
85
119
  compute_stdv(noise_multiplier, clipping_norm, num_sampled_clients),
86
120
  )
87
121
  return ndarrays_to_parameters(model_params_ndarrays)
122
+
123
+
124
def compute_adaptive_noise_params(
    noise_multiplier: float,
    num_sampled_clients: float,
    clipped_count_stddev: Optional[float],
) -> Tuple[float, float]:
    """Compute noising parameters for the adaptive clipping.

    Paper: https://arxiv.org/abs/1905.03871

    Parameters
    ----------
    noise_multiplier : float
        Desired effective noise multiplier. A value <= 0 disables noising.
    num_sampled_clients : float
        Number of sampled clients. It is used to derive the default
        ``clipped_count_stddev`` (``num_sampled_clients / 20``) when noising
        is enabled.
    clipped_count_stddev : Optional[float]
        Standard deviation used for the clipped count. If None, it defaults to
        ``num_sampled_clients / 20`` when noising is enabled, or 0.0 otherwise.

    Returns
    -------
    Tuple[float, float]
        ``(clipped_count_stddev, noise_multiplier_value)``. The second element
        is ``(noise_multiplier**-2 - (2 * clipped_count_stddev)**-2)**-0.5``
        when noising is enabled, and 0.0 otherwise.

    Raises
    ------
    ValueError
        If ``noise_multiplier >= 2 * clipped_count_stddev``. In that case the
        formula above has no real, finite solution.
    """
    if noise_multiplier > 0:
        if clipped_count_stddev is None:
            clipped_count_stddev = num_sampled_clients / 20
        # The update noise multiplier below is real and finite only when
        # noise_multiplier < 2 * clipped_count_stddev.
        if noise_multiplier >= 2 * clipped_count_stddev:
            raise ValueError(
                f"The `clipped_count_stddev` value ({clipped_count_stddev}) is "
                f"too low to achieve the desired effective `noise_multiplier` "
                f"({noise_multiplier}); it must be greater than "
                f"`noise_multiplier`/2. If not specified, "
                f"`clipped_count_stddev` is set to `num_sampled_clients`/20 by "
                f"default. Consider increasing `clipped_count_stddev` or "
                f"decreasing `noise_multiplier`."
            )
        noise_multiplier_value = (
            noise_multiplier ** (-2) - (2 * clipped_count_stddev) ** (-2)
        ) ** -0.5

        # Ratio of the update noise multiplier to the requested one.
        adding_noise = noise_multiplier_value / noise_multiplier
        if adding_noise >= 2:
            log(
                WARNING,
                "A significant amount of noise (%s) has to be "
                "added. Consider increasing `clipped_count_stddev` or "
                "`num_sampled_clients`.",
                adding_noise,
            )

    else:
        if clipped_count_stddev is None:
            clipped_count_stddev = 0.0
        noise_multiplier_value = 0.0

    return clipped_count_stddev, noise_multiplier_value
@@ -16,6 +16,7 @@
16
16
 
17
17
 
18
18
  KEY_CLIPPING_NORM = "clipping_norm"
19
+ KEY_NORM_BIT = "norm_bit"
19
20
  CLIENTS_DISCREPANCY_WARNING = (
20
21
  "The number of clients returning parameters (%s)"
21
22
  " differs from the number of sampled clients (%s)."
@@ -14,33 +14,55 @@
14
14
  # ==============================================================================
15
15
  """Constants for the SecAgg/SecAgg+ protocol."""
16
16
 
17
+
18
+ from __future__ import annotations
19
+
17
20
  RECORD_KEY_STATE = "secaggplus_state"
18
21
  RECORD_KEY_CONFIGS = "secaggplus_configs"
19
22
 
20
- # Names of stages
21
- STAGE_SETUP = "setup"
22
- STAGE_SHARE_KEYS = "share_keys"
23
- STAGE_COLLECT_MASKED_INPUT = "collect_masked_input"
24
- STAGE_UNMASK = "unmask"
25
- STAGES = (STAGE_SETUP, STAGE_SHARE_KEYS, STAGE_COLLECT_MASKED_INPUT, STAGE_UNMASK)
26
-
27
- # All valid keys in received/replied `named_values` dictionaries
28
- KEY_STAGE = "stage"
29
- KEY_SAMPLE_NUMBER = "sample_num"
30
- KEY_SECURE_ID = "secure_id"
31
- KEY_SHARE_NUMBER = "share_num"
32
- KEY_THRESHOLD = "threshold"
33
- KEY_CLIPPING_RANGE = "clipping_range"
34
- KEY_TARGET_RANGE = "target_range"
35
- KEY_MOD_RANGE = "mod_range"
36
- KEY_PUBLIC_KEY_1 = "pk1"
37
- KEY_PUBLIC_KEY_2 = "pk2"
38
- KEY_DESTINATION_LIST = "dsts"
39
- KEY_CIPHERTEXT_LIST = "ctxts"
40
- KEY_SOURCE_LIST = "srcs"
41
- KEY_PARAMETERS = "params"
42
- KEY_MASKED_PARAMETERS = "masked_params"
43
- KEY_ACTIVE_SECURE_ID_LIST = "active_sids"
44
- KEY_DEAD_SECURE_ID_LIST = "dead_sids"
45
- KEY_SECURE_ID_LIST = "sids"
46
- KEY_SHARE_LIST = "shares"
23
+
24
class Stage:
    """Names of the SecAgg+ protocol stages.

    This is a namespace of string constants and cannot be instantiated.
    """

    SETUP = "setup"
    SHARE_KEYS = "share_keys"
    COLLECT_MASKED_INPUT = "collect_masked_input"
    UNMASK = "unmask"
    # Stages in protocol order; the client derives the expected next stage
    # from the position of the current one in this tuple.
    _stages = (SETUP, SHARE_KEYS, COLLECT_MASKED_INPUT, UNMASK)

    @classmethod
    def all(cls) -> tuple[str, str, str, str]:
        """Return every stage name, in protocol order."""
        ordered = cls._stages
        return ordered

    def __new__(cls) -> Stage:
        """Refuse to create instances of this constants-only class."""
        name = cls.__name__
        raise TypeError(f"{name} cannot be instantiated.")
41
+
42
+
43
class Key:
    """Keys used in the SecAgg+ configs exchanged through a ConfigsRecord.

    This is a namespace of string constants and cannot be instantiated.
    """

    # Name of the stage the message belongs to.
    STAGE = "stage"

    # Setup stage: configuration sent to the client.
    SAMPLE_NUMBER = "sample_num"
    SECURE_ID = "secure_id"
    SHARE_NUMBER = "share_num"
    THRESHOLD = "threshold"
    CLIPPING_RANGE = "clipping_range"
    TARGET_RANGE = "target_range"
    MOD_RANGE = "mod_range"
    # Setup stage: public keys returned by the client.
    PUBLIC_KEY_1 = "pk1"
    PUBLIC_KEY_2 = "pk2"

    # Share-keys stage: reply (destinations and their encrypted shares).
    DESTINATION_LIST = "dsts"
    CIPHERTEXT_LIST = "ctxts"

    # Collect-masked-input stage: incoming sources, outgoing masked params.
    SOURCE_LIST = "srcs"
    PARAMETERS = "params"
    MASKED_PARAMETERS = "masked_params"

    # Unmask stage: surviving/dropped secure IDs in, key shares out.
    ACTIVE_SECURE_ID_LIST = "active_sids"
    DEAD_SECURE_ID_LIST = "dead_sids"
    SECURE_ID_LIST = "sids"
    SHARE_LIST = "shares"

    def __new__(cls) -> Key:
        """Refuse to create instances of this constants-only class."""
        name = cls.__name__
        raise TypeError(f"{name} cannot be instantiated.")
@@ -0,0 +1,26 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
3
+ # source: flwr/proto/error.proto
4
+ # Protobuf Python Version: 4.25.0
5
+ """Generated protocol buffer code."""
6
+ from google.protobuf import descriptor as _descriptor
7
+ from google.protobuf import descriptor_pool as _descriptor_pool
8
+ from google.protobuf import symbol_database as _symbol_database
9
+ from google.protobuf.internal import builder as _builder
10
+ # @@protoc_insertion_point(imports)
11
+
12
+ _sym_db = _symbol_database.Default()
13
+
14
+
15
+
16
+
17
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/error.proto\x12\nflwr.proto\"%\n\x05\x45rror\x12\x0c\n\x04\x63ode\x18\x01 \x01(\x12\x12\x0e\n\x06reason\x18\x02 \x01(\tb\x06proto3')
18
+
19
+ _globals = globals()
20
+ _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
21
+ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.error_pb2', _globals)
22
+ if _descriptor._USE_C_DESCRIPTORS == False:
23
+ DESCRIPTOR._options = None
24
+ _globals['_ERROR']._serialized_start=38
25
+ _globals['_ERROR']._serialized_end=75
26
+ # @@protoc_insertion_point(module_scope)
@@ -0,0 +1,25 @@
1
+ """
2
+ @generated by mypy-protobuf. Do not edit manually!
3
+ isort:skip_file
4
+ """
5
+ import builtins
6
+ import google.protobuf.descriptor
7
+ import google.protobuf.message
8
+ import typing
9
+ import typing_extensions
10
+
11
+ DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
12
+
13
+ class Error(google.protobuf.message.Message):
14
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
15
+ CODE_FIELD_NUMBER: builtins.int
16
+ REASON_FIELD_NUMBER: builtins.int
17
+ code: builtins.int
18
+ reason: typing.Text
19
+ def __init__(self,
20
+ *,
21
+ code: builtins.int = ...,
22
+ reason: typing.Text = ...,
23
+ ) -> None: ...
24
+ def ClearField(self, field_name: typing_extensions.Literal["code",b"code","reason",b"reason"]) -> None: ...
25
+ global___Error = Error
@@ -0,0 +1,4 @@
1
+ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
2
+ """Client and server classes corresponding to protobuf-defined services."""
3
+ import grpc
4
+
@@ -0,0 +1,4 @@
1
+ """
2
+ @generated by mypy-protobuf. Do not edit manually!
3
+ isort:skip_file
4
+ """
flwr/proto/task_pb2.py CHANGED
@@ -15,19 +15,20 @@ _sym_db = _symbol_database.Default()
15
15
  from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2
16
16
  from flwr.proto import recordset_pb2 as flwr_dot_proto_dot_recordset__pb2
17
17
  from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2
18
+ from flwr.proto import error_pb2 as flwr_dot_proto_dot_error__pb2
18
19
 
19
20
 
20
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xd4\x01\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\t\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x0b\n\x03ttl\x18\x05 \x01(\t\x12\x10\n\x08\x61ncestry\x18\x06 \x03(\t\x12\x11\n\ttask_type\x18\x07 \x01(\t\x12(\n\trecordset\x18\x08 \x01(\x0b\x32\x15.flwr.proto.RecordSet\"\\\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\\\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Taskb\x06proto3')
21
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x16\x66lwr/proto/error.proto\"\xf6\x01\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\t\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x0b\n\x03ttl\x18\x05 \x01(\t\x12\x10\n\x08\x61ncestry\x18\x06 \x03(\t\x12\x11\n\ttask_type\x18\x07 \x01(\t\x12(\n\trecordset\x18\x08 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12 \n\x05\x65rror\x18\t \x01(\x0b\x32\x11.flwr.proto.Error\"\\\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\\\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Taskb\x06proto3')
21
22
 
22
23
  _globals = globals()
23
24
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
24
25
  _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.task_pb2', _globals)
25
26
  if _descriptor._USE_C_DESCRIPTORS == False:
26
27
  DESCRIPTOR._options = None
27
- _globals['_TASK']._serialized_start=117
28
- _globals['_TASK']._serialized_end=329
29
- _globals['_TASKINS']._serialized_start=331
30
- _globals['_TASKINS']._serialized_end=423
31
- _globals['_TASKRES']._serialized_start=425
32
- _globals['_TASKRES']._serialized_end=517
28
+ _globals['_TASK']._serialized_start=141
29
+ _globals['_TASK']._serialized_end=387
30
+ _globals['_TASKINS']._serialized_start=389
31
+ _globals['_TASKINS']._serialized_end=481
32
+ _globals['_TASKRES']._serialized_start=483
33
+ _globals['_TASKRES']._serialized_end=575
33
34
  # @@protoc_insertion_point(module_scope)