flwr-nightly 1.8.0.dev20240227__py3-none-any.whl → 1.8.0.dev20240229__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- flwr/client/mod/__init__.py +3 -2
- flwr/client/mod/centraldp_mods.py +63 -2
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +55 -75
- flwr/common/differential_privacy.py +77 -0
- flwr/common/differential_privacy_constants.py +1 -0
- flwr/common/secure_aggregation/secaggplus_constants.py +49 -27
- flwr/proto/error_pb2.py +26 -0
- flwr/proto/error_pb2.pyi +25 -0
- flwr/proto/error_pb2_grpc.py +4 -0
- flwr/proto/error_pb2_grpc.pyi +4 -0
- flwr/proto/task_pb2.py +8 -7
- flwr/proto/task_pb2.pyi +7 -2
- flwr/server/__init__.py +4 -0
- flwr/server/app.py +8 -31
- flwr/server/client_proxy.py +5 -0
- flwr/server/compat/__init__.py +2 -0
- flwr/server/compat/app.py +7 -88
- flwr/server/compat/app_utils.py +102 -0
- flwr/server/compat/driver_client_proxy.py +22 -10
- flwr/server/compat/legacy_context.py +55 -0
- flwr/server/run_serverapp.py +1 -1
- flwr/server/server.py +18 -8
- flwr/server/strategy/__init__.py +24 -14
- flwr/server/strategy/dp_adaptive_clipping.py +449 -0
- flwr/server/strategy/dp_fixed_clipping.py +5 -7
- flwr/server/superlink/driver/driver_grpc.py +54 -0
- flwr/server/superlink/driver/driver_servicer.py +4 -4
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +5 -0
- flwr/server/superlink/fleet/vce/__init__.py +1 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -4
- flwr/server/superlink/fleet/vce/vce_api.py +236 -16
- flwr/server/typing.py +1 -0
- flwr/server/workflow/__init__.py +22 -0
- flwr/server/workflow/default_workflows.py +357 -0
- flwr/simulation/__init__.py +3 -0
- flwr/simulation/ray_transport/ray_client_proxy.py +28 -8
- flwr/simulation/run_simulation.py +177 -0
- {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/METADATA +4 -3
- {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/RECORD +42 -31
- {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/entry_points.txt +1 -0
- {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240227.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/WHEEL +0 -0
flwr/server/server.py
CHANGED
@@ -89,7 +89,7 @@ class Server:
|
|
89
89
|
|
90
90
|
# Initialize parameters
|
91
91
|
log(INFO, "Initializing global parameters")
|
92
|
-
self.parameters = self._get_initial_parameters(timeout=timeout)
|
92
|
+
self.parameters = self._get_initial_parameters(server_round=0, timeout=timeout)
|
93
93
|
log(INFO, "Evaluating initial parameters")
|
94
94
|
res = self.strategy.evaluate(0, parameters=self.parameters)
|
95
95
|
if res is not None:
|
@@ -185,6 +185,7 @@ class Server:
|
|
185
185
|
client_instructions,
|
186
186
|
max_workers=self.max_workers,
|
187
187
|
timeout=timeout,
|
188
|
+
group_id=server_round,
|
188
189
|
)
|
189
190
|
log(
|
190
191
|
DEBUG,
|
@@ -234,6 +235,7 @@ class Server:
|
|
234
235
|
client_instructions=client_instructions,
|
235
236
|
max_workers=self.max_workers,
|
236
237
|
timeout=timeout,
|
238
|
+
group_id=server_round,
|
237
239
|
)
|
238
240
|
log(
|
239
241
|
DEBUG,
|
@@ -264,7 +266,9 @@ class Server:
|
|
264
266
|
timeout=timeout,
|
265
267
|
)
|
266
268
|
|
267
|
-
def _get_initial_parameters(
|
269
|
+
def _get_initial_parameters(
|
270
|
+
self, server_round: int, timeout: Optional[float]
|
271
|
+
) -> Parameters:
|
268
272
|
"""Get initial parameters from one of the available clients."""
|
269
273
|
# Server-side parameter initialization
|
270
274
|
parameters: Optional[Parameters] = self.strategy.initialize_parameters(
|
@@ -278,7 +282,9 @@ class Server:
|
|
278
282
|
log(INFO, "Requesting initial parameters from one random client")
|
279
283
|
random_client = self._client_manager.sample(1)[0]
|
280
284
|
ins = GetParametersIns(config={})
|
281
|
-
get_parameters_res = random_client.get_parameters(
|
285
|
+
get_parameters_res = random_client.get_parameters(
|
286
|
+
ins=ins, timeout=timeout, group_id=server_round
|
287
|
+
)
|
282
288
|
log(INFO, "Received initial parameters from one random client")
|
283
289
|
return get_parameters_res.parameters
|
284
290
|
|
@@ -321,6 +327,7 @@ def reconnect_client(
|
|
321
327
|
disconnect = client.reconnect(
|
322
328
|
reconnect,
|
323
329
|
timeout=timeout,
|
330
|
+
group_id=None,
|
324
331
|
)
|
325
332
|
return client, disconnect
|
326
333
|
|
@@ -329,11 +336,12 @@ def fit_clients(
|
|
329
336
|
client_instructions: List[Tuple[ClientProxy, FitIns]],
|
330
337
|
max_workers: Optional[int],
|
331
338
|
timeout: Optional[float],
|
339
|
+
group_id: int,
|
332
340
|
) -> FitResultsAndFailures:
|
333
341
|
"""Refine parameters concurrently on all selected clients."""
|
334
342
|
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
335
343
|
submitted_fs = {
|
336
|
-
executor.submit(fit_client, client_proxy, ins, timeout)
|
344
|
+
executor.submit(fit_client, client_proxy, ins, timeout, group_id)
|
337
345
|
for client_proxy, ins in client_instructions
|
338
346
|
}
|
339
347
|
finished_fs, _ = concurrent.futures.wait(
|
@@ -352,10 +360,10 @@ def fit_clients(
|
|
352
360
|
|
353
361
|
|
354
362
|
def fit_client(
|
355
|
-
client: ClientProxy, ins: FitIns, timeout: Optional[float]
|
363
|
+
client: ClientProxy, ins: FitIns, timeout: Optional[float], group_id: int
|
356
364
|
) -> Tuple[ClientProxy, FitRes]:
|
357
365
|
"""Refine parameters on a single client."""
|
358
|
-
fit_res = client.fit(ins, timeout=timeout)
|
366
|
+
fit_res = client.fit(ins, timeout=timeout, group_id=group_id)
|
359
367
|
return client, fit_res
|
360
368
|
|
361
369
|
|
@@ -388,11 +396,12 @@ def evaluate_clients(
|
|
388
396
|
client_instructions: List[Tuple[ClientProxy, EvaluateIns]],
|
389
397
|
max_workers: Optional[int],
|
390
398
|
timeout: Optional[float],
|
399
|
+
group_id: int,
|
391
400
|
) -> EvaluateResultsAndFailures:
|
392
401
|
"""Evaluate parameters concurrently on all selected clients."""
|
393
402
|
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
394
403
|
submitted_fs = {
|
395
|
-
executor.submit(evaluate_client, client_proxy, ins, timeout)
|
404
|
+
executor.submit(evaluate_client, client_proxy, ins, timeout, group_id)
|
396
405
|
for client_proxy, ins in client_instructions
|
397
406
|
}
|
398
407
|
finished_fs, _ = concurrent.futures.wait(
|
@@ -414,9 +423,10 @@ def evaluate_client(
|
|
414
423
|
client: ClientProxy,
|
415
424
|
ins: EvaluateIns,
|
416
425
|
timeout: Optional[float],
|
426
|
+
group_id: int,
|
417
427
|
) -> Tuple[ClientProxy, EvaluateRes]:
|
418
428
|
"""Evaluate parameters on a single client."""
|
419
|
-
evaluate_res = client.evaluate(ins, timeout=timeout)
|
429
|
+
evaluate_res = client.evaluate(ins, timeout=timeout, group_id=group_id)
|
420
430
|
return client, evaluate_res
|
421
431
|
|
422
432
|
|
flwr/server/strategy/__init__.py
CHANGED
@@ -16,9 +16,17 @@
|
|
16
16
|
|
17
17
|
|
18
18
|
from .bulyan import Bulyan as Bulyan
|
19
|
+
from .dp_adaptive_clipping import (
|
20
|
+
DifferentialPrivacyClientSideAdaptiveClipping as DifferentialPrivacyClientSideAdaptiveClipping,
|
21
|
+
)
|
22
|
+
from .dp_adaptive_clipping import (
|
23
|
+
DifferentialPrivacyServerSideAdaptiveClipping as DifferentialPrivacyServerSideAdaptiveClipping,
|
24
|
+
)
|
25
|
+
from .dp_fixed_clipping import (
|
26
|
+
DifferentialPrivacyClientSideFixedClipping as DifferentialPrivacyClientSideFixedClipping,
|
27
|
+
)
|
19
28
|
from .dp_fixed_clipping import (
|
20
|
-
|
21
|
-
DifferentialPrivacyServerSideFixedClipping,
|
29
|
+
DifferentialPrivacyServerSideFixedClipping as DifferentialPrivacyServerSideFixedClipping,
|
22
30
|
)
|
23
31
|
from .dpfedavg_adaptive import DPFedAvgAdaptive as DPFedAvgAdaptive
|
24
32
|
from .dpfedavg_fixed import DPFedAvgFixed as DPFedAvgFixed
|
@@ -41,26 +49,28 @@ from .qfedavg import QFedAvg as QFedAvg
|
|
41
49
|
from .strategy import Strategy as Strategy
|
42
50
|
|
43
51
|
__all__ = [
|
44
|
-
"
|
52
|
+
"Bulyan",
|
53
|
+
"DPFedAvgAdaptive",
|
54
|
+
"DPFedAvgFixed",
|
55
|
+
"DifferentialPrivacyClientSideAdaptiveClipping",
|
56
|
+
"DifferentialPrivacyServerSideAdaptiveClipping",
|
57
|
+
"DifferentialPrivacyClientSideFixedClipping",
|
58
|
+
"DifferentialPrivacyServerSideFixedClipping",
|
45
59
|
"FedAdagrad",
|
46
60
|
"FedAdam",
|
47
61
|
"FedAvg",
|
48
|
-
"FedXgbNnAvg",
|
49
|
-
"FedXgbBagging",
|
50
|
-
"FedXgbCyclic",
|
51
62
|
"FedAvgAndroid",
|
52
63
|
"FedAvgM",
|
64
|
+
"FedMedian",
|
53
65
|
"FedOpt",
|
54
66
|
"FedProx",
|
55
|
-
"FedYogi",
|
56
|
-
"QFedAvg",
|
57
|
-
"FedMedian",
|
58
67
|
"FedTrimmedAvg",
|
68
|
+
"FedXgbBagging",
|
69
|
+
"FedXgbCyclic",
|
70
|
+
"FedXgbNnAvg",
|
71
|
+
"FedYogi",
|
72
|
+
"FaultTolerantFedAvg",
|
59
73
|
"Krum",
|
60
|
-
"
|
61
|
-
"DPFedAvgAdaptive",
|
62
|
-
"DPFedAvgFixed",
|
74
|
+
"QFedAvg",
|
63
75
|
"Strategy",
|
64
|
-
"DifferentialPrivacyServerSideFixedClipping",
|
65
|
-
"DifferentialPrivacyClientSideFixedClipping",
|
66
76
|
]
|
@@ -0,0 +1,449 @@
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Central differential privacy with adaptive clipping.
|
16
|
+
|
17
|
+
Paper (Andrew et al.): https://arxiv.org/abs/1905.03871
|
18
|
+
"""
|
19
|
+
|
20
|
+
|
21
|
+
import math
|
22
|
+
from logging import WARNING
|
23
|
+
from typing import Dict, List, Optional, Tuple, Union
|
24
|
+
|
25
|
+
import numpy as np
|
26
|
+
|
27
|
+
from flwr.common import (
|
28
|
+
EvaluateIns,
|
29
|
+
EvaluateRes,
|
30
|
+
FitIns,
|
31
|
+
FitRes,
|
32
|
+
NDArrays,
|
33
|
+
Parameters,
|
34
|
+
Scalar,
|
35
|
+
ndarrays_to_parameters,
|
36
|
+
parameters_to_ndarrays,
|
37
|
+
)
|
38
|
+
from flwr.common.differential_privacy import (
|
39
|
+
adaptive_clip_inputs_inplace,
|
40
|
+
add_gaussian_noise_to_params,
|
41
|
+
compute_adaptive_noise_params,
|
42
|
+
)
|
43
|
+
from flwr.common.differential_privacy_constants import (
|
44
|
+
CLIENTS_DISCREPANCY_WARNING,
|
45
|
+
KEY_CLIPPING_NORM,
|
46
|
+
KEY_NORM_BIT,
|
47
|
+
)
|
48
|
+
from flwr.common.logger import log
|
49
|
+
from flwr.server.client_manager import ClientManager
|
50
|
+
from flwr.server.client_proxy import ClientProxy
|
51
|
+
from flwr.server.strategy.strategy import Strategy
|
52
|
+
|
53
|
+
|
54
|
+
class DifferentialPrivacyServerSideAdaptiveClipping(Strategy):
    """Strategy wrapper for central DP with server-side adaptive clipping.

    Each client update (relative to the global model sent in `configure_fit`)
    is clipped on the server with the current clipping norm. The clipped
    results are aggregated by the wrapped strategy. Gaussian noise calibrated
    to that same clipping norm is then added to the aggregate. Finally, the
    clipping norm is adapted geometrically from a noised count of set norm
    bits.

    Parameters
    ----------
    strategy : Strategy
        The strategy to which DP functionalities will be added by this wrapper.
    noise_multiplier : float
        The noise multiplier for the Gaussian mechanism for model updates.
    num_sampled_clients : int
        The number of clients that are sampled on each round.
    initial_clipping_norm : float
        The initial value of clipping norm. Defaults to 0.1.
        Andrew et al. recommends to set to 0.1.
    target_clipped_quantile : float
        The desired quantile of updates which should be clipped. Defaults to 0.5.
    clip_norm_lr : float
        The learning rate for the clipping norm adaptation. Defaults to 0.2.
        Andrew et al. recommends to set to 0.2.
    clipped_count_stddev : Optional[float]
        The standard deviation of the noise added to the count of updates below
        the estimate. Andrew et al. recommends to set to `expected_num_records/20`.
        The given value (possibly None) is passed to
        `compute_adaptive_noise_params`, which returns the stddev that is
        actually used together with the noise multiplier for model updates.

    Raises
    ------
    ValueError
        If `strategy` is None or any numeric argument is out of range.

    Examples
    --------
    Create a strategy:

    >>> strategy = fl.server.strategy.FedAvg( ... )

    Wrap the strategy with the DifferentialPrivacyServerSideAdaptiveClipping wrapper

    >>> dp_strategy = DifferentialPrivacyServerSideAdaptiveClipping(
    >>>     strategy, cfg.noise_multiplier, cfg.num_sampled_clients, ...
    >>> )
    """

    # pylint: disable=too-many-arguments,too-many-instance-attributes
    def __init__(
        self,
        strategy: Strategy,
        noise_multiplier: float,
        num_sampled_clients: int,
        initial_clipping_norm: float = 0.1,
        target_clipped_quantile: float = 0.5,
        clip_norm_lr: float = 0.2,
        clipped_count_stddev: Optional[float] = None,
    ) -> None:
        super().__init__()

        if strategy is None:
            raise ValueError("The passed strategy is None.")

        if noise_multiplier < 0:
            raise ValueError("The noise multiplier should be a non-negative value.")

        if num_sampled_clients <= 0:
            raise ValueError(
                "The number of sampled clients should be a positive value."
            )

        if initial_clipping_norm <= 0:
            raise ValueError("The initial clipping norm should be a positive value.")

        if not 0 <= target_clipped_quantile <= 1:
            raise ValueError(
                "The target clipped quantile must be between 0 and 1 (inclusive)."
            )

        if clip_norm_lr <= 0:
            raise ValueError("The learning rate must be positive.")

        if clipped_count_stddev is not None and clipped_count_stddev < 0:
            raise ValueError("The `clipped_count_stddev` must be non-negative.")

        self.strategy = strategy
        self.num_sampled_clients = num_sampled_clients
        self.clipping_norm = initial_clipping_norm
        self.target_clipped_quantile = target_clipped_quantile
        self.clip_norm_lr = clip_norm_lr
        # The helper yields the stddev for the clipped-count noise and the
        # noise multiplier that is actually applied to the model aggregate.
        (
            self.clipped_count_stddev,
            self.noise_multiplier,
        ) = compute_adaptive_noise_params(
            noise_multiplier,
            num_sampled_clients,
            clipped_count_stddev,
        )

        # Global model sent out in the current round. `aggregate_fit` needs it
        # to recover each client's model update before clipping.
        self.current_round_params: NDArrays = []

    def __repr__(self) -> str:
        """Compute a string representation of the strategy."""
        rep = "Differential Privacy Strategy Wrapper (Server-Side Adaptive Clipping)"
        return rep

    def initialize_parameters(
        self, client_manager: ClientManager
    ) -> Optional[Parameters]:
        """Initialize global model parameters using given strategy."""
        return self.strategy.initialize_parameters(client_manager)

    def configure_fit(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, FitIns]]:
        """Configure the next round of training.

        Also remembers `parameters`, which serve as the reference point for
        computing client updates in `aggregate_fit`.
        """
        self.current_round_params = parameters_to_ndarrays(parameters)
        return self.strategy.configure_fit(server_round, parameters, client_manager)

    def configure_evaluate(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, EvaluateIns]]:
        """Configure the next round of evaluation."""
        return self.strategy.configure_evaluate(
            server_round, parameters, client_manager
        )

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Clip client updates, aggregate them, add noise and adapt the clip norm.

        Returns
        -------
        Tuple[Optional[Parameters], Dict[str, Scalar]]
            The noised aggregate and the wrapped strategy's metrics. Returns
            ``(None, {})`` when any client failed or no results were received.
        """
        # Nothing can be aggregated from an empty round. An empty `results`
        # would also make the norm-bit fraction divide by zero.
        if failures or not results:
            return None, {}

        if len(results) != self.num_sampled_clients:
            log(
                WARNING,
                CLIENTS_DISCREPANCY_WARNING,
                len(results),
                self.num_sampled_clients,
            )

        # The sensitivity of this round's aggregate is set by the norm the
        # updates are clipped with right now. The Gaussian noise must use this
        # value, not the norm adapted for the next round below.
        # (Andrew et al., Algorithm 1: noise with C, then update C.)
        round_clipping_norm = self.clipping_norm

        norm_bit_set_count = 0
        for _, res in results:
            param = parameters_to_ndarrays(res.parameters)
            # Compute and clip the update relative to this round's global model
            model_update = [
                np.subtract(x, y) for (x, y) in zip(param, self.current_round_params)
            ]

            norm_bit = adaptive_clip_inputs_inplace(model_update, round_clipping_norm)
            norm_bit_set_count += norm_bit

            # Rebuild the client's parameters from the clipped update
            for i, base in enumerate(self.current_round_params):
                param[i] = base + model_update[i]
            res.parameters = ndarrays_to_parameters(param)

        self._update_clip_norm(norm_bit_set_count, len(results))

        aggregated_params, metrics = self.strategy.aggregate_fit(
            server_round, results, failures
        )

        # Add Gaussian noise calibrated to the norm used for clipping above
        if aggregated_params:
            aggregated_params = add_gaussian_noise_to_params(
                aggregated_params,
                self.noise_multiplier,
                round_clipping_norm,
                self.num_sampled_clients,
            )

        return aggregated_params, metrics

    def _update_clip_norm(self, norm_bit_set_count: int, num_results: int) -> None:
        """Adapt the clipping norm from the number of set norm bits.

        Gaussian noise with stddev ``self.clipped_count_stddev`` is added to the
        count before it is turned into a fraction of ``num_results``. The caller
        guarantees that ``num_results`` is positive.
        """
        noised_norm_bit_set_count = float(
            np.random.normal(norm_bit_set_count, self.clipped_count_stddev)
        )
        noised_norm_bit_set_fraction = noised_norm_bit_set_count / num_results
        # Geometric update: the norm shrinks when the noised fraction exceeds
        # the target and grows when it falls below it.
        # NOTE(review): this matches Andrew et al. only if the norm bit means
        # "update was NOT clipped" (norm <= clipping norm). The class docstring
        # instead describes the target as a fraction of *clipped* updates.
        # Confirm what `adaptive_clip_inputs_inplace` returns.
        self.clipping_norm *= math.exp(
            -self.clip_norm_lr
            * (noised_norm_bit_set_fraction - self.target_clipped_quantile)
        )

    def aggregate_evaluate(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, EvaluateRes]],
        failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
    ) -> Tuple[Optional[float], Dict[str, Scalar]]:
        """Aggregate evaluation losses using the given strategy."""
        return self.strategy.aggregate_evaluate(server_round, results, failures)

    def evaluate(
        self, server_round: int, parameters: Parameters
    ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
        """Evaluate model parameters using an evaluation function from the strategy."""
        return self.strategy.evaluate(server_round, parameters)
|
245
|
+
|
246
|
+
|
247
|
+
class DifferentialPrivacyClientSideAdaptiveClipping(Strategy):
    """Strategy wrapper for central DP with client-side adaptive clipping.

    Use `adaptiveclipping_mod` modifier at the client side.

    In comparison to `DifferentialPrivacyServerSideAdaptiveClipping`,
    which performs clipping on the server-side, `DifferentialPrivacyClientSideAdaptiveClipping`
    expects clipping to happen on the client-side, usually by using the built-in
    `adaptiveclipping_mod`.

    The current clipping norm is sent to clients in the fit config under
    `KEY_CLIPPING_NORM`. Each client must report a norm bit under `KEY_NORM_BIT`
    in its fit metrics. The norm bits drive the geometric adaptation of the
    clipping norm.

    Parameters
    ----------
    strategy : Strategy
        The strategy to which DP functionalities will be added by this wrapper.
    noise_multiplier : float
        The noise multiplier for the Gaussian mechanism for model updates.
    num_sampled_clients : int
        The number of clients that are sampled on each round.
    initial_clipping_norm : float
        The initial value of clipping norm. Defaults to 0.1.
        Andrew et al. recommends to set to 0.1.
    target_clipped_quantile : float
        The desired quantile of updates which should be clipped. Defaults to 0.5.
    clip_norm_lr : float
        The learning rate for the clipping norm adaptation. Defaults to 0.2.
        Andrew et al. recommends to set to 0.2.
    clipped_count_stddev : Optional[float]
        The stddev of the noise added to the count of updates currently below the estimate.
        Andrew et al. recommends to set to `expected_num_records/20`.
        The given value (possibly None) is passed to
        `compute_adaptive_noise_params`, which returns the stddev that is
        actually used together with the noise multiplier for model updates.

    Raises
    ------
    ValueError
        If `strategy` is None or any numeric argument is out of range.

    Examples
    --------
    Create a strategy:

    >>> strategy = fl.server.strategy.FedAvg(...)

    Wrap the strategy with the `DifferentialPrivacyClientSideAdaptiveClipping` wrapper:

    >>> DifferentialPrivacyClientSideAdaptiveClipping(
    >>>     strategy, cfg.noise_multiplier, cfg.num_sampled_clients
    >>> )

    On the client, add the `adaptiveclipping_mod` to the client-side mods:

    >>> app = fl.client.ClientApp(
    >>>     client_fn=client_fn, mods=[adaptiveclipping_mod]
    >>> )
    """

    # pylint: disable=too-many-arguments,too-many-instance-attributes
    def __init__(
        self,
        strategy: Strategy,
        noise_multiplier: float,
        num_sampled_clients: int,
        initial_clipping_norm: float = 0.1,
        target_clipped_quantile: float = 0.5,
        clip_norm_lr: float = 0.2,
        clipped_count_stddev: Optional[float] = None,
    ) -> None:
        super().__init__()

        if strategy is None:
            raise ValueError("The passed strategy is None.")

        if noise_multiplier < 0:
            raise ValueError("The noise multiplier should be a non-negative value.")

        if num_sampled_clients <= 0:
            raise ValueError(
                "The number of sampled clients should be a positive value."
            )

        if initial_clipping_norm <= 0:
            raise ValueError("The initial clipping norm should be a positive value.")

        if not 0 <= target_clipped_quantile <= 1:
            raise ValueError(
                "The target clipped quantile must be between 0 and 1 (inclusive)."
            )

        if clip_norm_lr <= 0:
            raise ValueError("The learning rate must be positive.")

        if clipped_count_stddev is not None and clipped_count_stddev < 0:
            raise ValueError("The `clipped_count_stddev` must be non-negative.")

        self.strategy = strategy
        self.num_sampled_clients = num_sampled_clients
        self.clipping_norm = initial_clipping_norm
        self.target_clipped_quantile = target_clipped_quantile
        self.clip_norm_lr = clip_norm_lr
        # The helper yields the stddev for the clipped-count noise and the
        # noise multiplier that is actually applied to the model aggregate.
        (
            self.clipped_count_stddev,
            self.noise_multiplier,
        ) = compute_adaptive_noise_params(
            noise_multiplier,
            num_sampled_clients,
            clipped_count_stddev,
        )

    def __repr__(self) -> str:
        """Compute a string representation of the strategy."""
        rep = "Differential Privacy Strategy Wrapper (Client-Side Adaptive Clipping)"
        return rep

    def initialize_parameters(
        self, client_manager: ClientManager
    ) -> Optional[Parameters]:
        """Initialize global model parameters using given strategy."""
        return self.strategy.initialize_parameters(client_manager)

    def configure_fit(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, FitIns]]:
        """Configure the next round of training.

        Adds the current clipping norm (under `KEY_CLIPPING_NORM`) to the config
        of every `FitIns` produced by the wrapped strategy.
        """
        additional_config = {KEY_CLIPPING_NORM: self.clipping_norm}
        inner_strategy_config_result = self.strategy.configure_fit(
            server_round, parameters, client_manager
        )
        for _, fit_ins in inner_strategy_config_result:
            fit_ins.config.update(additional_config)

        return inner_strategy_config_result

    def configure_evaluate(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, EvaluateIns]]:
        """Configure the next round of evaluation."""
        return self.strategy.configure_evaluate(
            server_round, parameters, client_manager
        )

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate client-clipped results, add noise and adapt the clip norm.

        Returns
        -------
        Tuple[Optional[Parameters], Dict[str, Scalar]]
            The noised aggregate and the wrapped strategy's metrics. Returns
            ``(None, {})`` when any client failed or no results were received.

        Raises
        ------
        KeyError
            If a client did not report `KEY_NORM_BIT` in its fit metrics.
        """
        # Nothing can be aggregated from an empty round. An empty `results`
        # would also make the norm-bit fraction divide by zero.
        if failures or not results:
            return None, {}

        if len(results) != self.num_sampled_clients:
            log(
                WARNING,
                CLIENTS_DISCREPANCY_WARNING,
                len(results),
                self.num_sampled_clients,
            )

        # Clients clipped with the norm sent in `configure_fit`. That is still
        # `self.clipping_norm` here, because only `_update_clip_norm` changes it.
        # Snapshot it before adapting so the noise matches this round's
        # sensitivity rather than next round's norm.
        round_clipping_norm = self.clipping_norm

        aggregated_params, metrics = self.strategy.aggregate_fit(
            server_round, results, failures
        )
        self._update_clip_norm(results)

        # Add Gaussian noise calibrated to the norm the clients clipped with
        if aggregated_params:
            aggregated_params = add_gaussian_noise_to_params(
                aggregated_params,
                self.noise_multiplier,
                round_clipping_norm,
                self.num_sampled_clients,
            )

        return aggregated_params, metrics

    def _update_clip_norm(self, results: List[Tuple[ClientProxy, FitRes]]) -> None:
        """Adapt the clipping norm from the norm bits reported by clients.

        The caller guarantees that `results` is non-empty.

        Raises
        ------
        KeyError
            If a client did not report `KEY_NORM_BIT` in its fit metrics.
        """
        # Count the clients that set the norm indicator bit
        norm_bit_set_count = 0
        for client_proxy, fit_res in results:
            if KEY_NORM_BIT not in fit_res.metrics:
                raise KeyError(
                    f"{KEY_NORM_BIT} not returned by client with id {client_proxy.cid}."
                )
            if fit_res.metrics[KEY_NORM_BIT]:
                norm_bit_set_count += 1
        # Add noise to the count
        noised_norm_bit_set_count = float(
            np.random.normal(norm_bit_set_count, self.clipped_count_stddev)
        )

        noised_norm_bit_set_fraction = noised_norm_bit_set_count / len(results)
        # Geometric update: the norm shrinks when the noised fraction exceeds
        # the target and grows when it falls below it.
        # NOTE(review): this matches Andrew et al. only if the client's norm bit
        # means "update was NOT clipped". The class docstring instead describes
        # the target as a fraction of *clipped* updates. Confirm the bit
        # semantics in `adaptiveclipping_mod`.
        self.clipping_norm *= math.exp(
            -self.clip_norm_lr
            * (noised_norm_bit_set_fraction - self.target_clipped_quantile)
        )

    def aggregate_evaluate(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, EvaluateRes]],
        failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
    ) -> Tuple[Optional[float], Dict[str, Scalar]]:
        """Aggregate evaluation losses using the given strategy."""
        return self.strategy.aggregate_evaluate(server_round, results, failures)

    def evaluate(
        self, server_round: int, parameters: Parameters
    ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
        """Evaluate model parameters using an evaluation function from the strategy."""
        return self.strategy.evaluate(server_round, parameters)
|
@@ -47,8 +47,7 @@ from flwr.server.strategy.strategy import Strategy
|
|
47
47
|
|
48
48
|
|
49
49
|
class DifferentialPrivacyServerSideFixedClipping(Strategy):
|
50
|
-
"""Strategy wrapper for central
|
51
|
-
clipping.
|
50
|
+
"""Strategy wrapper for central DP with server-side fixed clipping.
|
52
51
|
|
53
52
|
Parameters
|
54
53
|
----------
|
@@ -192,15 +191,14 @@ class DifferentialPrivacyServerSideFixedClipping(Strategy):
|
|
192
191
|
|
193
192
|
|
194
193
|
class DifferentialPrivacyClientSideFixedClipping(Strategy):
|
195
|
-
"""Strategy wrapper for central
|
196
|
-
clipping.
|
194
|
+
"""Strategy wrapper for central DP with client-side fixed clipping.
|
197
195
|
|
198
196
|
Use `fixedclipping_mod` modifier at the client side.
|
199
197
|
|
200
198
|
In comparison to `DifferentialPrivacyServerSideFixedClipping`,
|
201
199
|
which performs clipping on the server-side, `DifferentialPrivacyClientSideFixedClipping`
|
202
200
|
expects clipping to happen on the client-side, usually by using the built-in
|
203
|
-
`fixedclipping_mod
|
201
|
+
`fixedclipping_mod`.
|
204
202
|
|
205
203
|
Parameters
|
206
204
|
----------
|
@@ -220,7 +218,7 @@ class DifferentialPrivacyClientSideFixedClipping(Strategy):
|
|
220
218
|
|
221
219
|
>>> strategy = fl.server.strategy.FedAvg(...)
|
222
220
|
|
223
|
-
Wrap the strategy with the `
|
221
|
+
Wrap the strategy with the `DifferentialPrivacyClientSideFixedClipping` wrapper:
|
224
222
|
|
225
223
|
>>> DifferentialPrivacyClientSideFixedClipping(
|
226
224
|
>>> strategy, cfg.noise_multiplier, cfg.clipping_norm, cfg.num_sampled_clients
|
@@ -229,7 +227,7 @@ class DifferentialPrivacyClientSideFixedClipping(Strategy):
|
|
229
227
|
On the client, add the `fixedclipping_mod` to the client-side mods:
|
230
228
|
|
231
229
|
>>> app = fl.client.ClientApp(
|
232
|
-
>>> client_fn=
|
230
|
+
>>> client_fn=client_fn, mods=[fixedclipping_mod]
|
233
231
|
>>> )
|
234
232
|
"""
|
235
233
|
|