flwr-nightly 1.8.0.dev20240227__py3-none-any.whl → 1.8.0.dev20240229__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (42) hide show
  1. flwr/client/mod/__init__.py +3 -2
  2. flwr/client/mod/centraldp_mods.py +63 -2
  3. flwr/client/mod/secure_aggregation/secaggplus_mod.py +55 -75
  4. flwr/common/differential_privacy.py +77 -0
  5. flwr/common/differential_privacy_constants.py +1 -0
  6. flwr/common/secure_aggregation/secaggplus_constants.py +49 -27
  7. flwr/proto/error_pb2.py +26 -0
  8. flwr/proto/error_pb2.pyi +25 -0
  9. flwr/proto/error_pb2_grpc.py +4 -0
  10. flwr/proto/error_pb2_grpc.pyi +4 -0
  11. flwr/proto/task_pb2.py +8 -7
  12. flwr/proto/task_pb2.pyi +7 -2
  13. flwr/server/__init__.py +4 -0
  14. flwr/server/app.py +8 -31
  15. flwr/server/client_proxy.py +5 -0
  16. flwr/server/compat/__init__.py +2 -0
  17. flwr/server/compat/app.py +7 -88
  18. flwr/server/compat/app_utils.py +102 -0
  19. flwr/server/compat/driver_client_proxy.py +22 -10
  20. flwr/server/compat/legacy_context.py +55 -0
  21. flwr/server/run_serverapp.py +1 -1
  22. flwr/server/server.py +18 -8
  23. flwr/server/strategy/__init__.py +24 -14
  24. flwr/server/strategy/dp_adaptive_clipping.py +449 -0
  25. flwr/server/strategy/dp_fixed_clipping.py +5 -7
  26. flwr/server/superlink/driver/driver_grpc.py +54 -0
  27. flwr/server/superlink/driver/driver_servicer.py +4 -4
  28. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +5 -0
  29. flwr/server/superlink/fleet/vce/__init__.py +1 -1
  30. flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -4
  31. flwr/server/superlink/fleet/vce/vce_api.py +236 -16
  32. flwr/server/typing.py +1 -0
  33. flwr/server/workflow/__init__.py +22 -0
  34. flwr/server/workflow/default_workflows.py +357 -0
  35. flwr/simulation/__init__.py +3 -0
  36. flwr/simulation/ray_transport/ray_client_proxy.py +28 -8
  37. flwr/simulation/run_simulation.py +177 -0
  38. {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/METADATA +4 -3
  39. {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/RECORD +42 -31
  40. {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/entry_points.txt +1 -0
  41. {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/LICENSE +0 -0
  42. {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/WHEEL +0 -0
flwr/server/server.py CHANGED
@@ -89,7 +89,7 @@ class Server:
89
89
 
90
90
  # Initialize parameters
91
91
  log(INFO, "Initializing global parameters")
92
- self.parameters = self._get_initial_parameters(timeout=timeout)
92
+ self.parameters = self._get_initial_parameters(server_round=0, timeout=timeout)
93
93
  log(INFO, "Evaluating initial parameters")
94
94
  res = self.strategy.evaluate(0, parameters=self.parameters)
95
95
  if res is not None:
@@ -185,6 +185,7 @@ class Server:
185
185
  client_instructions,
186
186
  max_workers=self.max_workers,
187
187
  timeout=timeout,
188
+ group_id=server_round,
188
189
  )
189
190
  log(
190
191
  DEBUG,
@@ -234,6 +235,7 @@ class Server:
234
235
  client_instructions=client_instructions,
235
236
  max_workers=self.max_workers,
236
237
  timeout=timeout,
238
+ group_id=server_round,
237
239
  )
238
240
  log(
239
241
  DEBUG,
@@ -264,7 +266,9 @@ class Server:
264
266
  timeout=timeout,
265
267
  )
266
268
 
267
- def _get_initial_parameters(self, timeout: Optional[float]) -> Parameters:
269
+ def _get_initial_parameters(
270
+ self, server_round: int, timeout: Optional[float]
271
+ ) -> Parameters:
268
272
  """Get initial parameters from one of the available clients."""
269
273
  # Server-side parameter initialization
270
274
  parameters: Optional[Parameters] = self.strategy.initialize_parameters(
@@ -278,7 +282,9 @@ class Server:
278
282
  log(INFO, "Requesting initial parameters from one random client")
279
283
  random_client = self._client_manager.sample(1)[0]
280
284
  ins = GetParametersIns(config={})
281
- get_parameters_res = random_client.get_parameters(ins=ins, timeout=timeout)
285
+ get_parameters_res = random_client.get_parameters(
286
+ ins=ins, timeout=timeout, group_id=server_round
287
+ )
282
288
  log(INFO, "Received initial parameters from one random client")
283
289
  return get_parameters_res.parameters
284
290
 
@@ -321,6 +327,7 @@ def reconnect_client(
321
327
  disconnect = client.reconnect(
322
328
  reconnect,
323
329
  timeout=timeout,
330
+ group_id=None,
324
331
  )
325
332
  return client, disconnect
326
333
 
@@ -329,11 +336,12 @@ def fit_clients(
329
336
  client_instructions: List[Tuple[ClientProxy, FitIns]],
330
337
  max_workers: Optional[int],
331
338
  timeout: Optional[float],
339
+ group_id: int,
332
340
  ) -> FitResultsAndFailures:
333
341
  """Refine parameters concurrently on all selected clients."""
334
342
  with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
335
343
  submitted_fs = {
336
- executor.submit(fit_client, client_proxy, ins, timeout)
344
+ executor.submit(fit_client, client_proxy, ins, timeout, group_id)
337
345
  for client_proxy, ins in client_instructions
338
346
  }
339
347
  finished_fs, _ = concurrent.futures.wait(
@@ -352,10 +360,10 @@ def fit_clients(
352
360
 
353
361
 
354
362
  def fit_client(
355
- client: ClientProxy, ins: FitIns, timeout: Optional[float]
363
+ client: ClientProxy, ins: FitIns, timeout: Optional[float], group_id: int
356
364
  ) -> Tuple[ClientProxy, FitRes]:
357
365
  """Refine parameters on a single client."""
358
- fit_res = client.fit(ins, timeout=timeout)
366
+ fit_res = client.fit(ins, timeout=timeout, group_id=group_id)
359
367
  return client, fit_res
360
368
 
361
369
 
@@ -388,11 +396,12 @@ def evaluate_clients(
388
396
  client_instructions: List[Tuple[ClientProxy, EvaluateIns]],
389
397
  max_workers: Optional[int],
390
398
  timeout: Optional[float],
399
+ group_id: int,
391
400
  ) -> EvaluateResultsAndFailures:
392
401
  """Evaluate parameters concurrently on all selected clients."""
393
402
  with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
394
403
  submitted_fs = {
395
- executor.submit(evaluate_client, client_proxy, ins, timeout)
404
+ executor.submit(evaluate_client, client_proxy, ins, timeout, group_id)
396
405
  for client_proxy, ins in client_instructions
397
406
  }
398
407
  finished_fs, _ = concurrent.futures.wait(
@@ -414,9 +423,10 @@ def evaluate_client(
414
423
  client: ClientProxy,
415
424
  ins: EvaluateIns,
416
425
  timeout: Optional[float],
426
+ group_id: int,
417
427
  ) -> Tuple[ClientProxy, EvaluateRes]:
418
428
  """Evaluate parameters on a single client."""
419
- evaluate_res = client.evaluate(ins, timeout=timeout)
429
+ evaluate_res = client.evaluate(ins, timeout=timeout, group_id=group_id)
420
430
  return client, evaluate_res
421
431
 
422
432
 
@@ -16,9 +16,17 @@
16
16
 
17
17
 
18
18
  from .bulyan import Bulyan as Bulyan
19
+ from .dp_adaptive_clipping import (
20
+ DifferentialPrivacyClientSideAdaptiveClipping as DifferentialPrivacyClientSideAdaptiveClipping,
21
+ )
22
+ from .dp_adaptive_clipping import (
23
+ DifferentialPrivacyServerSideAdaptiveClipping as DifferentialPrivacyServerSideAdaptiveClipping,
24
+ )
25
+ from .dp_fixed_clipping import (
26
+ DifferentialPrivacyClientSideFixedClipping as DifferentialPrivacyClientSideFixedClipping,
27
+ )
19
28
  from .dp_fixed_clipping import (
20
- DifferentialPrivacyClientSideFixedClipping,
21
- DifferentialPrivacyServerSideFixedClipping,
29
+ DifferentialPrivacyServerSideFixedClipping as DifferentialPrivacyServerSideFixedClipping,
22
30
  )
23
31
  from .dpfedavg_adaptive import DPFedAvgAdaptive as DPFedAvgAdaptive
24
32
  from .dpfedavg_fixed import DPFedAvgFixed as DPFedAvgFixed
@@ -41,26 +49,28 @@ from .qfedavg import QFedAvg as QFedAvg
41
49
  from .strategy import Strategy as Strategy
42
50
 
43
51
  __all__ = [
44
- "FaultTolerantFedAvg",
52
+ "Bulyan",
53
+ "DPFedAvgAdaptive",
54
+ "DPFedAvgFixed",
55
+ "DifferentialPrivacyClientSideAdaptiveClipping",
56
+ "DifferentialPrivacyServerSideAdaptiveClipping",
57
+ "DifferentialPrivacyClientSideFixedClipping",
58
+ "DifferentialPrivacyServerSideFixedClipping",
45
59
  "FedAdagrad",
46
60
  "FedAdam",
47
61
  "FedAvg",
48
- "FedXgbNnAvg",
49
- "FedXgbBagging",
50
- "FedXgbCyclic",
51
62
  "FedAvgAndroid",
52
63
  "FedAvgM",
64
+ "FedMedian",
53
65
  "FedOpt",
54
66
  "FedProx",
55
- "FedYogi",
56
- "QFedAvg",
57
- "FedMedian",
58
67
  "FedTrimmedAvg",
68
+ "FedXgbBagging",
69
+ "FedXgbCyclic",
70
+ "FedXgbNnAvg",
71
+ "FedYogi",
72
+ "FaultTolerantFedAvg",
59
73
  "Krum",
60
- "Bulyan",
61
- "DPFedAvgAdaptive",
62
- "DPFedAvgFixed",
74
+ "QFedAvg",
63
75
  "Strategy",
64
- "DifferentialPrivacyServerSideFixedClipping",
65
- "DifferentialPrivacyClientSideFixedClipping",
66
76
  ]
@@ -0,0 +1,449 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Central differential privacy with adaptive clipping.
16
+
17
+ Paper (Andrew et al.): https://arxiv.org/abs/1905.03871
18
+ """
19
+
20
+
21
+ import math
22
+ from logging import WARNING
23
+ from typing import Dict, List, Optional, Tuple, Union
24
+
25
+ import numpy as np
26
+
27
+ from flwr.common import (
28
+ EvaluateIns,
29
+ EvaluateRes,
30
+ FitIns,
31
+ FitRes,
32
+ NDArrays,
33
+ Parameters,
34
+ Scalar,
35
+ ndarrays_to_parameters,
36
+ parameters_to_ndarrays,
37
+ )
38
+ from flwr.common.differential_privacy import (
39
+ adaptive_clip_inputs_inplace,
40
+ add_gaussian_noise_to_params,
41
+ compute_adaptive_noise_params,
42
+ )
43
+ from flwr.common.differential_privacy_constants import (
44
+ CLIENTS_DISCREPANCY_WARNING,
45
+ KEY_CLIPPING_NORM,
46
+ KEY_NORM_BIT,
47
+ )
48
+ from flwr.common.logger import log
49
+ from flwr.server.client_manager import ClientManager
50
+ from flwr.server.client_proxy import ClientProxy
51
+ from flwr.server.strategy.strategy import Strategy
52
+
53
+
54
class DifferentialPrivacyServerSideAdaptiveClipping(Strategy):
    """Strategy wrapper for central DP with server-side adaptive clipping.

    Each round, the server clips every client's model update (taken relative to
    the global parameters sent out in `configure_fit`) to the current clipping
    norm, lets the wrapped strategy aggregate the clipped results, adds Gaussian
    noise calibrated to that same clipping norm, and then adapts the clipping
    norm geometrically from a noised count of set norm bits.

    Parameters
    ----------
    strategy : Strategy
        The strategy to which DP functionalities will be added by this wrapper.
    noise_multiplier : float
        The noise multiplier for the Gaussian mechanism for model updates.
    num_sampled_clients : int
        The number of clients that are sampled on each round.
    initial_clipping_norm : float
        The initial value of clipping norm. Defaults to 0.1.
        Andrew et al. recommends to set to 0.1.
    target_clipped_quantile : float
        The desired quantile of updates which should be clipped. Defaults to 0.5.
    clip_norm_lr : float
        The learning rate for the clipping norm adaptation. Defaults to 0.2.
        Andrew et al. recommends to set to 0.2.
    clipped_count_stddev : Optional[float]
        The standard deviation of the noise added to the count of updates below
        the estimate. Andrew et al. recommends to set to
        `expected_num_records/20`. The value (possibly None) is passed to
        `compute_adaptive_noise_params`, whose result is what is actually used.

    Examples
    --------
    Create a strategy:

    >>> strategy = fl.server.strategy.FedAvg( ... )

    Wrap the strategy with the DifferentialPrivacyServerSideAdaptiveClipping wrapper

    >>> dp_strategy = DifferentialPrivacyServerSideAdaptiveClipping(
    >>>     strategy, cfg.noise_multiplier, cfg.num_sampled_clients, ...
    >>> )
    """

    # pylint: disable=too-many-arguments,too-many-instance-attributes
    def __init__(
        self,
        strategy: Strategy,
        noise_multiplier: float,
        num_sampled_clients: int,
        initial_clipping_norm: float = 0.1,
        target_clipped_quantile: float = 0.5,
        clip_norm_lr: float = 0.2,
        clipped_count_stddev: Optional[float] = None,
    ) -> None:
        super().__init__()

        if strategy is None:
            raise ValueError("The passed strategy is None.")

        if noise_multiplier < 0:
            raise ValueError("The noise multiplier should be a non-negative value.")

        if num_sampled_clients <= 0:
            raise ValueError(
                "The number of sampled clients should be a positive value."
            )

        if initial_clipping_norm <= 0:
            raise ValueError("The initial clipping norm should be a positive value.")

        if not 0 <= target_clipped_quantile <= 1:
            raise ValueError(
                "The target clipped quantile must be between 0 and 1 (inclusive)."
            )

        if clip_norm_lr <= 0:
            raise ValueError("The learning rate must be positive.")

        if clipped_count_stddev is not None and clipped_count_stddev < 0:
            raise ValueError("The `clipped_count_stddev` must be non-negative.")

        self.strategy = strategy
        self.num_sampled_clients = num_sampled_clients
        self.clipping_norm = initial_clipping_norm
        self.target_clipped_quantile = target_clipped_quantile
        self.clip_norm_lr = clip_norm_lr
        # Both the count-noise stddev and the noise multiplier applied to the
        # aggregated parameters come from this helper; the stored
        # `noise_multiplier` may therefore differ from the argument passed in.
        (
            self.clipped_count_stddev,
            self.noise_multiplier,
        ) = compute_adaptive_noise_params(
            noise_multiplier,
            num_sampled_clients,
            clipped_count_stddev,
        )

        # Global parameters of the current round, captured in `configure_fit`
        # and used as the reference point for computing client model updates.
        self.current_round_params: NDArrays = []

    def __repr__(self) -> str:
        """Compute a string representation of the strategy."""
        rep = "Differential Privacy Strategy Wrapper (Server-Side Adaptive Clipping)"
        return rep

    def initialize_parameters(
        self, client_manager: ClientManager
    ) -> Optional[Parameters]:
        """Initialize global model parameters using given strategy."""
        return self.strategy.initialize_parameters(client_manager)

    def configure_fit(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, FitIns]]:
        """Configure the next round of training.

        Remembers the outgoing global parameters so that `aggregate_fit` can
        compute each client's model update relative to them.
        """
        self.current_round_params = parameters_to_ndarrays(parameters)
        return self.strategy.configure_fit(server_round, parameters, client_manager)

    def configure_evaluate(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, EvaluateIns]]:
        """Configure the next round of evaluation."""
        return self.strategy.configure_evaluate(
            server_round, parameters, client_manager
        )

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate training results and update clip norms.

        Client results are clipped in place (their `parameters` are replaced)
        before being passed to the wrapped strategy. Returns ``(None, {})`` if
        any client failed or if no results were received.
        """
        if failures:
            return None, {}

        # Nothing to clip or aggregate; also avoids dividing by zero when the
        # noised norm-bit fraction is computed below.
        if not results:
            return None, {}

        if len(results) != self.num_sampled_clients:
            log(
                WARNING,
                CLIENTS_DISCREPANCY_WARNING,
                len(results),
                self.num_sampled_clients,
            )

        # The norm that bounds this round's updates. The Gaussian noise added to
        # the aggregate must be calibrated to this value, not to the norm that
        # the geometric update below produces for the next round.
        round_clipping_norm = self.clipping_norm

        norm_bit_set_count = 0
        for _, res in results:
            param = parameters_to_ndarrays(res.parameters)
            # Compute and clip update
            model_update = [
                np.subtract(x, y) for (x, y) in zip(param, self.current_round_params)
            ]

            norm_bit = adaptive_clip_inputs_inplace(model_update, round_clipping_norm)
            norm_bit_set_count += norm_bit

            # Re-apply the clipped update on top of the global parameters
            for i, base in enumerate(self.current_round_params):
                param[i] = base + model_update[i]
            # Convert back to parameters
            res.parameters = ndarrays_to_parameters(param)

        # Noising the count
        noised_norm_bit_set_count = float(
            np.random.normal(norm_bit_set_count, self.clipped_count_stddev)
        )
        noised_norm_bit_set_fraction = noised_norm_bit_set_count / len(results)
        # Geometric update (takes effect from the next round on)
        self.clipping_norm *= math.exp(
            -self.clip_norm_lr
            * (noised_norm_bit_set_fraction - self.target_clipped_quantile)
        )

        aggregated_params, metrics = self.strategy.aggregate_fit(
            server_round, results, failures
        )

        # Add Gaussian noise to the aggregated parameters, scaled by the norm
        # that was actually used to clip this round's updates.
        if aggregated_params:
            aggregated_params = add_gaussian_noise_to_params(
                aggregated_params,
                self.noise_multiplier,
                round_clipping_norm,
                self.num_sampled_clients,
            )

        return aggregated_params, metrics

    def aggregate_evaluate(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, EvaluateRes]],
        failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
    ) -> Tuple[Optional[float], Dict[str, Scalar]]:
        """Aggregate evaluation losses using the given strategy."""
        return self.strategy.aggregate_evaluate(server_round, results, failures)

    def evaluate(
        self, server_round: int, parameters: Parameters
    ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
        """Evaluate model parameters using an evaluation function from the strategy."""
        return self.strategy.evaluate(server_round, parameters)
245
+
246
+
247
class DifferentialPrivacyClientSideAdaptiveClipping(Strategy):
    """Strategy wrapper for central DP with client-side adaptive clipping.

    Use `adaptiveclipping_mod` modifier at the client side.

    In comparison to `DifferentialPrivacyServerSideAdaptiveClipping`,
    which performs clipping on the server-side, `DifferentialPrivacyClientSideAdaptiveClipping`
    expects clipping to happen on the client-side, usually by using the built-in
    `adaptiveclipping_mod`.

    Each round, the current clipping norm is sent to clients in the fit config
    (under `KEY_CLIPPING_NORM`); clients are expected to report a norm bit in
    their fit metrics (under `KEY_NORM_BIT`). Gaussian noise calibrated to the
    norm sent out this round is added to the aggregate, after which the
    clipping norm is adapted geometrically from a noised count of set bits.

    Parameters
    ----------
    strategy : Strategy
        The strategy to which DP functionalities will be added by this wrapper.
    noise_multiplier : float
        The noise multiplier for the Gaussian mechanism for model updates.
    num_sampled_clients : int
        The number of clients that are sampled on each round.
    initial_clipping_norm : float
        The initial value of clipping norm. Defaults to 0.1.
        Andrew et al. recommends to set to 0.1.
    target_clipped_quantile : float
        The desired quantile of updates which should be clipped. Defaults to 0.5.
    clip_norm_lr : float
        The learning rate for the clipping norm adaptation. Defaults to 0.2.
        Andrew et al. recommends to set to 0.2.
    clipped_count_stddev : Optional[float]
        The stddev of the noise added to the count of updates currently below
        the estimate. Andrew et al. recommends to set to
        `expected_num_records/20`. The value (possibly None) is passed to
        `compute_adaptive_noise_params`, whose result is what is actually used.

    Examples
    --------
    Create a strategy:

    >>> strategy = fl.server.strategy.FedAvg(...)

    Wrap the strategy with the `DifferentialPrivacyClientSideAdaptiveClipping` wrapper:

    >>> DifferentialPrivacyClientSideAdaptiveClipping(
    >>>     strategy, cfg.noise_multiplier, cfg.num_sampled_clients
    >>> )

    On the client, add the `adaptiveclipping_mod` to the client-side mods:

    >>> app = fl.client.ClientApp(
    >>>     client_fn=client_fn, mods=[adaptiveclipping_mod]
    >>> )
    """

    # pylint: disable=too-many-arguments,too-many-instance-attributes
    def __init__(
        self,
        strategy: Strategy,
        noise_multiplier: float,
        num_sampled_clients: int,
        initial_clipping_norm: float = 0.1,
        target_clipped_quantile: float = 0.5,
        clip_norm_lr: float = 0.2,
        clipped_count_stddev: Optional[float] = None,
    ) -> None:
        super().__init__()

        if strategy is None:
            raise ValueError("The passed strategy is None.")

        if noise_multiplier < 0:
            raise ValueError("The noise multiplier should be a non-negative value.")

        if num_sampled_clients <= 0:
            raise ValueError(
                "The number of sampled clients should be a positive value."
            )

        if initial_clipping_norm <= 0:
            raise ValueError("The initial clipping norm should be a positive value.")

        if not 0 <= target_clipped_quantile <= 1:
            raise ValueError(
                "The target clipped quantile must be between 0 and 1 (inclusive)."
            )

        if clip_norm_lr <= 0:
            raise ValueError("The learning rate must be positive.")

        if clipped_count_stddev is not None and clipped_count_stddev < 0:
            raise ValueError("The `clipped_count_stddev` must be non-negative.")

        self.strategy = strategy
        self.num_sampled_clients = num_sampled_clients
        self.clipping_norm = initial_clipping_norm
        self.target_clipped_quantile = target_clipped_quantile
        self.clip_norm_lr = clip_norm_lr
        # Both the count-noise stddev and the noise multiplier applied to the
        # aggregated parameters come from this helper; the stored
        # `noise_multiplier` may therefore differ from the argument passed in.
        (
            self.clipped_count_stddev,
            self.noise_multiplier,
        ) = compute_adaptive_noise_params(
            noise_multiplier,
            num_sampled_clients,
            clipped_count_stddev,
        )

    def __repr__(self) -> str:
        """Compute a string representation of the strategy."""
        rep = "Differential Privacy Strategy Wrapper (Client-Side Adaptive Clipping)"
        return rep

    def initialize_parameters(
        self, client_manager: ClientManager
    ) -> Optional[Parameters]:
        """Initialize global model parameters using given strategy."""
        return self.strategy.initialize_parameters(client_manager)

    def configure_fit(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, FitIns]]:
        """Configure the next round of training.

        Adds the current clipping norm to every client's fit config so that the
        client-side mod can clip its update with it.
        """
        additional_config = {KEY_CLIPPING_NORM: self.clipping_norm}
        inner_strategy_config_result = self.strategy.configure_fit(
            server_round, parameters, client_manager
        )
        for _, fit_ins in inner_strategy_config_result:
            fit_ins.config.update(additional_config)

        return inner_strategy_config_result

    def configure_evaluate(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, EvaluateIns]]:
        """Configure the next round of evaluation."""
        return self.strategy.configure_evaluate(
            server_round, parameters, client_manager
        )

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate training results and update clip norms.

        Returns ``(None, {})`` if any client failed or if no results were
        received.

        Raises
        ------
        KeyError
            If a client's fit metrics do not contain `KEY_NORM_BIT`.
        """
        if failures:
            return None, {}

        # Nothing to aggregate; also avoids dividing by zero when the noised
        # norm-bit fraction is computed in `_update_clip_norm`.
        if not results:
            return None, {}

        if len(results) != self.num_sampled_clients:
            log(
                WARNING,
                CLIENTS_DISCREPANCY_WARNING,
                len(results),
                self.num_sampled_clients,
            )

        # The norm sent to clients in `configure_fit` for this round, i.e. the
        # bound on this round's updates. The aggregate's noise must be
        # calibrated to it, not to the norm adapted for the next round.
        round_clipping_norm = self.clipping_norm

        aggregated_params, metrics = self.strategy.aggregate_fit(
            server_round, results, failures
        )
        self._update_clip_norm(results)

        # Add Gaussian noise to the aggregated parameters
        if aggregated_params:
            aggregated_params = add_gaussian_noise_to_params(
                aggregated_params,
                self.noise_multiplier,
                round_clipping_norm,
                self.num_sampled_clients,
            )

        return aggregated_params, metrics

    def _update_clip_norm(self, results: List[Tuple[ClientProxy, FitRes]]) -> None:
        """Geometrically adapt `self.clipping_norm` from the clients' norm bits.

        Does nothing when `results` is empty (there is no fraction to compute).

        Raises
        ------
        KeyError
            If a client's fit metrics do not contain `KEY_NORM_BIT`.
        """
        if not results:
            return

        # Calculate the number of clients which set the norm indicator bit
        norm_bit_set_count = 0
        for client_proxy, fit_res in results:
            if KEY_NORM_BIT not in fit_res.metrics:
                raise KeyError(
                    f"{KEY_NORM_BIT} not returned by client with id {client_proxy.cid}."
                )
            if fit_res.metrics[KEY_NORM_BIT]:
                norm_bit_set_count += 1
        # Add noise to the count
        noised_norm_bit_set_count = float(
            np.random.normal(norm_bit_set_count, self.clipped_count_stddev)
        )

        noised_norm_bit_set_fraction = noised_norm_bit_set_count / len(results)
        # Geometric update
        self.clipping_norm *= math.exp(
            -self.clip_norm_lr
            * (noised_norm_bit_set_fraction - self.target_clipped_quantile)
        )

    def aggregate_evaluate(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, EvaluateRes]],
        failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
    ) -> Tuple[Optional[float], Dict[str, Scalar]]:
        """Aggregate evaluation losses using the given strategy."""
        return self.strategy.aggregate_evaluate(server_round, results, failures)

    def evaluate(
        self, server_round: int, parameters: Parameters
    ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
        """Evaluate model parameters using an evaluation function from the strategy."""
        return self.strategy.evaluate(server_round, parameters)
@@ -47,8 +47,7 @@ from flwr.server.strategy.strategy import Strategy
47
47
 
48
48
 
49
49
  class DifferentialPrivacyServerSideFixedClipping(Strategy):
50
- """Strategy wrapper for central differential privacy with server-side fixed
51
- clipping.
50
+ """Strategy wrapper for central DP with server-side fixed clipping.
52
51
 
53
52
  Parameters
54
53
  ----------
@@ -192,15 +191,14 @@ class DifferentialPrivacyServerSideFixedClipping(Strategy):
192
191
 
193
192
 
194
193
  class DifferentialPrivacyClientSideFixedClipping(Strategy):
195
- """Strategy wrapper for central differential privacy with client-side fixed
196
- clipping.
194
+ """Strategy wrapper for central DP with client-side fixed clipping.
197
195
 
198
196
  Use `fixedclipping_mod` modifier at the client side.
199
197
 
200
198
  In comparison to `DifferentialPrivacyServerSideFixedClipping`,
201
199
  which performs clipping on the server-side, `DifferentialPrivacyClientSideFixedClipping`
202
200
  expects clipping to happen on the client-side, usually by using the built-in
203
- `fixedclipping_mod `.
201
+ `fixedclipping_mod`.
204
202
 
205
203
  Parameters
206
204
  ----------
@@ -220,7 +218,7 @@ class DifferentialPrivacyClientSideFixedClipping(Strategy):
220
218
 
221
219
  >>> strategy = fl.server.strategy.FedAvg(...)
222
220
 
223
- Wrap the strategy with the `DifferentialPrivacyServerSideFixedClipping` wrapper:
221
+ Wrap the strategy with the `DifferentialPrivacyClientSideFixedClipping` wrapper:
224
222
 
225
223
  >>> DifferentialPrivacyClientSideFixedClipping(
226
224
  >>> strategy, cfg.noise_multiplier, cfg.clipping_norm, cfg.num_sampled_clients
@@ -229,7 +227,7 @@ class DifferentialPrivacyClientSideFixedClipping(Strategy):
229
227
  On the client, add the `fixedclipping_mod` to the client-side mods:
230
228
 
231
229
  >>> app = fl.client.ClientApp(
232
- >>> client_fn=FlowerClient().to_client(), mods=[fixedclipping_mod]
230
+ >>> client_fn=client_fn, mods=[fixedclipping_mod]
233
231
  >>> )
234
232
  """
235
233