flwr-nightly 1.22.0.dev20250917__py3-none-any.whl → 1.22.0.dev20250919__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/new/new.py +2 -0
- flwr/cli/new/templates/app/code/client.xgboost.py.tpl +110 -0
- flwr/cli/new/templates/app/code/server.xgboost.py.tpl +56 -0
- flwr/cli/new/templates/app/code/task.xgboost.py.tpl +67 -0
- flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +61 -0
- flwr/clientapp/mod/__init__.py +2 -1
- flwr/clientapp/mod/centraldp_mods.py +155 -39
- flwr/clientapp/typing.py +22 -0
- flwr/common/constant.py +1 -0
- flwr/common/exit/exit_code.py +4 -0
- flwr/common/record/typeddict.py +12 -0
- flwr/serverapp/strategy/__init__.py +12 -0
- flwr/serverapp/strategy/dp_adaptive_clipping.py +335 -0
- flwr/serverapp/strategy/dp_fixed_clipping.py +71 -49
- flwr/serverapp/strategy/fedadagrad.py +0 -3
- flwr/serverapp/strategy/fedadam.py +0 -3
- flwr/serverapp/strategy/fedavgm.py +3 -3
- flwr/serverapp/strategy/fedprox.py +1 -1
- flwr/serverapp/strategy/fedtrimmedavg.py +1 -1
- flwr/serverapp/strategy/fedxgb_cyclic.py +220 -0
- flwr/serverapp/strategy/fedyogi.py +0 -3
- flwr/serverapp/strategy/krum.py +230 -0
- flwr/serverapp/strategy/qfedavg.py +252 -0
- flwr/supercore/cli/flower_superexec.py +26 -1
- flwr/supercore/constant.py +19 -0
- flwr/supercore/superexec/plugin/exec_plugin.py +11 -1
- flwr/supercore/superexec/run_superexec.py +16 -2
- {flwr_nightly-1.22.0.dev20250917.dist-info → flwr_nightly-1.22.0.dev20250919.dist-info}/METADATA +1 -1
- {flwr_nightly-1.22.0.dev20250917.dist-info → flwr_nightly-1.22.0.dev20250919.dist-info}/RECORD +31 -23
- flwr/serverapp/dp_fixed_clipping.py +0 -352
- flwr/serverapp/strategy/strategy_utils_tests.py +0 -323
- {flwr_nightly-1.22.0.dev20250917.dist-info → flwr_nightly-1.22.0.dev20250919.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.22.0.dev20250917.dist-info → flwr_nightly-1.22.0.dev20250919.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,335 @@
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Message-based Central differential privacy with adaptive clipping.
|
16
|
+
|
17
|
+
Paper (Andrew et al.): https://arxiv.org/abs/1905.03871
|
18
|
+
"""
|
19
|
+
|
20
|
+
import math
|
21
|
+
from abc import ABC
|
22
|
+
from collections import OrderedDict
|
23
|
+
from collections.abc import Iterable
|
24
|
+
from logging import INFO
|
25
|
+
from typing import Optional
|
26
|
+
|
27
|
+
import numpy as np
|
28
|
+
|
29
|
+
from flwr.common import Array, ArrayRecord, ConfigRecord, Message, MetricRecord, log
|
30
|
+
from flwr.common.differential_privacy import (
|
31
|
+
adaptive_clip_inputs_inplace,
|
32
|
+
add_gaussian_noise_inplace,
|
33
|
+
compute_adaptive_noise_params,
|
34
|
+
compute_stdv,
|
35
|
+
)
|
36
|
+
from flwr.common.differential_privacy_constants import KEY_CLIPPING_NORM, KEY_NORM_BIT
|
37
|
+
from flwr.server import Grid
|
38
|
+
from flwr.serverapp.exception import AggregationError
|
39
|
+
|
40
|
+
from .dp_fixed_clipping import validate_replies
|
41
|
+
from .strategy import Strategy
|
42
|
+
|
43
|
+
|
44
|
+
class DifferentialPrivacyAdaptiveBase(Strategy, ABC):
    """Base class for DP strategies with adaptive clipping.

    Wraps another strategy and keeps the state shared by the adaptive clipping
    wrappers: the current clipping norm, the target quantile, the learning rate
    of the norm update and the noise parameters. Evaluation and summary calls
    are delegated to the wrapped strategy.

    Parameters
    ----------
    strategy : Strategy
        The strategy to which DP functionalities will be added by this wrapper.
    noise_multiplier : float
        The noise multiplier for the Gaussian mechanism for model updates.
        Must be non-negative. The value stored on the instance is the one
        returned by `compute_adaptive_noise_params`.
    num_sampled_clients : int
        The number of clients that are sampled on each round. Must be positive.
    initial_clipping_norm : float
        The initial value of clipping norm. Must be positive. Defaults to 0.1.
    target_clipped_quantile : float
        The desired quantile of updates which should be clipped. Must be in
        [0, 1]. Defaults to 0.5.
    clip_norm_lr : float
        The learning rate for the clipping norm adaptation. Must be positive.
        Defaults to 0.2.
    clipped_count_stddev : Optional[float]
        The stddev of the noise added to the count used for the norm update.
        Must be non-negative when given; the value (or `None`) is passed to
        `compute_adaptive_noise_params`, whose result is stored.

    Raises
    ------
    ValueError
        If `strategy` is None or any numeric argument violates the constraints
        listed above.
    """

    # pylint: disable=too-many-arguments,too-many-instance-attributes,too-many-positional-arguments
    def __init__(
        self,
        strategy: Strategy,
        noise_multiplier: float,
        num_sampled_clients: int,
        initial_clipping_norm: float = 0.1,
        target_clipped_quantile: float = 0.5,
        clip_norm_lr: float = 0.2,
        clipped_count_stddev: Optional[float] = None,
    ) -> None:
        super().__init__()

        if strategy is None:
            raise ValueError("The passed strategy is None.")
        if noise_multiplier < 0:
            raise ValueError("The noise multiplier should be a non-negative value.")
        if num_sampled_clients <= 0:
            raise ValueError(
                "The number of sampled clients should be a positive value."
            )
        if initial_clipping_norm <= 0:
            raise ValueError("The initial clipping norm should be a positive value.")
        if not 0 <= target_clipped_quantile <= 1:
            raise ValueError("The target clipped quantile must be in [0, 1].")
        if clip_norm_lr <= 0:
            raise ValueError("The learning rate must be positive.")
        if clipped_count_stddev is not None and clipped_count_stddev < 0:
            raise ValueError("The `clipped_count_stddev` must be non-negative.")

        self.strategy = strategy
        self.num_sampled_clients = num_sampled_clients
        self.clipping_norm = initial_clipping_norm
        self.target_clipped_quantile = target_clipped_quantile
        self.clip_norm_lr = clip_norm_lr
        # Both noise parameters are taken from the helper's return value, so the
        # stored noise multiplier may differ from the one passed in.
        (
            self.clipped_count_stddev,
            self.noise_multiplier,
        ) = compute_adaptive_noise_params(
            noise_multiplier,
            num_sampled_clients,
            clipped_count_stddev,
        )

    def _add_noise_to_aggregated_arrays(self, aggregated: ArrayRecord) -> ArrayRecord:
        """Add Gaussian noise to the aggregated arrays.

        The standard deviation comes from `compute_stdv`, called with the
        stored noise multiplier, the *current* `self.clipping_norm` and the
        number of sampled clients. The result is returned as a newly built
        `ArrayRecord` that reuses the keys of `aggregated`.
        """
        nds = aggregated.to_numpy_ndarrays()
        stdv = compute_stdv(
            self.noise_multiplier, self.clipping_norm, self.num_sampled_clients
        )
        add_gaussian_noise_inplace(nds, stdv)
        # Prefix matches the calling method of the message-based Strategy API
        log(INFO, "aggregate_train: central DP noise with %.4f stdev added", stdv)
        return ArrayRecord(
            OrderedDict({k: Array(v) for k, v in zip(aggregated.keys(), nds)})
        )

    def _noisy_fraction(self, count: int, total: int) -> float:
        """Return `count` perturbed by Gaussian noise, divided by `total`.

        The noise has standard deviation `self.clipped_count_stddev`. `total`
        must be non-zero; a zero `total` raises `ZeroDivisionError`.
        """
        return float(np.random.normal(count, self.clipped_count_stddev)) / float(total)

    def _geometric_update(self, clipped_fraction: float) -> None:
        """Scale the clipping norm in place.

        The norm is multiplied by
        `exp(-clip_norm_lr * (clipped_fraction - target_clipped_quantile))`.
        """
        self.clipping_norm *= math.exp(
            -self.clip_norm_lr * (clipped_fraction - self.target_clipped_quantile)
        )

    def configure_evaluate(
        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
    ) -> Iterable[Message]:
        """Configure the next round of federated evaluation."""
        # Evaluation is not affected by DP; defer to the wrapped strategy
        return self.strategy.configure_evaluate(server_round, arrays, config, grid)

    def aggregate_evaluate(
        self, server_round: int, replies: Iterable[Message]
    ) -> Optional[MetricRecord]:
        """Aggregate MetricRecords in the received Messages."""
        return self.strategy.aggregate_evaluate(server_round, replies)

    def summary(self) -> None:
        """Log summary configuration of the strategy."""
        self.strategy.summary()
|
125
|
+
|
126
|
+
|
127
|
+
class DifferentialPrivacyServerSideAdaptiveClipping(DifferentialPrivacyAdaptiveBase):
    """Message-based central DP with server-side adaptive clipping.

    The arrays sent out in `configure_train` are remembered. In
    `aggregate_train`, the difference between each reply's arrays and the
    remembered arrays is passed to `adaptive_clip_inputs_inplace`, the result
    is added back, the wrapped strategy aggregates the modified replies, and
    Gaussian noise is added to the aggregate. The clipping norm is then
    adapted for the following round.

    Parameters
    ----------
    strategy : Strategy
        The strategy to which DP functionalities will be added by this wrapper.
    noise_multiplier : float
        The noise multiplier for the Gaussian mechanism for model updates.
    num_sampled_clients : int
        The number of clients that are sampled on each round.
    initial_clipping_norm : float
        The initial value of clipping norm. Defaults to 0.1.
    target_clipped_quantile : float
        The desired quantile of updates which should be clipped. Defaults to 0.5.
    clip_norm_lr : float
        The learning rate for the clipping norm adaptation. Defaults to 0.2.
    clipped_count_stddev : Optional[float]
        The stddev of the noise added to the count used for the norm update.
    """

    # pylint: disable=too-many-arguments,too-many-locals,too-many-positional-arguments
    def __init__(
        self,
        strategy: Strategy,
        noise_multiplier: float,
        num_sampled_clients: int,
        initial_clipping_norm: float = 0.1,
        target_clipped_quantile: float = 0.5,
        clip_norm_lr: float = 0.2,
        clipped_count_stddev: Optional[float] = None,
    ) -> None:
        super().__init__(
            strategy,
            noise_multiplier,
            num_sampled_clients,
            initial_clipping_norm,
            target_clipped_quantile,
            clip_norm_lr,
            clipped_count_stddev,
        )
        # Arrays sent out in the most recent `configure_train` call
        self.current_arrays: ArrayRecord = ArrayRecord()

    def __repr__(self) -> str:
        """Compute a string representation of the strategy."""
        return "Differential Privacy Strategy Wrapper (Server-Side Adaptive Clipping)"

    def summary(self) -> None:
        """Log summary configuration of the strategy."""
        log(INFO, "\t├──> DP settings:")
        log(INFO, "\t│\t├── Noise multiplier: %s", self.noise_multiplier)
        log(INFO, "\t│\t├── Clipping norm: %s", self.clipping_norm)
        log(INFO, "\t│\t├── Target clipped quantile: %s", self.target_clipped_quantile)
        log(INFO, "\t│\t└── Clip norm learning rate: %s", self.clip_norm_lr)
        super().summary()

    def configure_train(
        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
    ) -> Iterable[Message]:
        """Configure the next round of training."""
        # Kept so that `aggregate_train` can compute each reply's model update
        self.current_arrays = arrays
        return self.strategy.configure_train(server_round, arrays, config, grid)

    def aggregate_train(
        self, server_round: int, replies: Iterable[Message]
    ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
        """Aggregate ArrayRecords and MetricRecords in the received Messages.

        Returns `(None, None)` when `validate_replies` rejects the replies or
        when there are no replies at all. Otherwise returns the wrapped
        strategy's aggregate (with Gaussian noise added when it is non-empty)
        and its aggregated metrics.
        """
        # Materialize before validating: `replies` may be a one-shot iterator
        # and `validate_replies` iterates over it, which would otherwise leave
        # nothing to clip or aggregate afterwards.
        replies_list = list(replies)
        if not validate_replies(replies_list, self.num_sampled_clients):
            return None, None
        if not replies_list:
            # Nothing to aggregate; also keeps `_noisy_fraction` from dividing
            # by zero.
            return None, None

        current_nd = self.current_arrays.to_numpy_ndarrays()
        clipped_indicator_count = 0

        for reply in replies_list:
            for arr_name, record in reply.content.array_records.items():
                reply_nd = record.to_numpy_ndarrays()
                model_update = [
                    np.subtract(x, y) for (x, y) in zip(reply_nd, current_nd)
                ]
                # NOTE(review): the flag's exact meaning is defined by
                # `adaptive_clip_inputs_inplace`; it is only counted here.
                norm_bit = adaptive_clip_inputs_inplace(
                    model_update, self.clipping_norm
                )
                clipped_indicator_count += int(norm_bit)
                # reconstruct array using clipped contribution from current round
                restored = [c + u for c, u in zip(current_nd, model_update)]
                reply.content[arr_name] = ArrayRecord(
                    OrderedDict({k: Array(v) for k, v in zip(record.keys(), restored)})
                )
            log(
                INFO,
                "aggregate_train: arrays in `ArrayRecord` are clipped by value: %.4f.",
                self.clipping_norm,
            )

        aggregated_arrays, aggregated_metrics = self.strategy.aggregate_train(
            server_round, replies_list
        )

        # Add the noise while `self.clipping_norm` still holds the value the
        # updates above were clipped with, so the noise scale matches it.
        if aggregated_arrays:
            aggregated_arrays = self._add_noise_to_aggregated_arrays(aggregated_arrays)

        # Only now adapt the clipping norm; the new value applies to the next
        # round.
        clipped_fraction = self._noisy_fraction(
            clipped_indicator_count, len(replies_list)
        )
        self._geometric_update(clipped_fraction)

        return aggregated_arrays, aggregated_metrics
|
217
|
+
|
218
|
+
|
219
|
+
class DifferentialPrivacyClientSideAdaptiveClipping(DifferentialPrivacyAdaptiveBase):
    """Strategy wrapper for central DP with client-side adaptive clipping.

    Use `adaptiveclipping_mod` modifier at the client side.

    In comparison to `DifferentialPrivacyServerSideAdaptiveClipping`,
    which performs clipping on the server-side,
    `DifferentialPrivacyClientSideAdaptiveClipping`
    expects clipping to happen on the client-side, usually by using the built-in
    `adaptiveclipping_mod`.

    Parameters
    ----------
    strategy : Strategy
        The strategy to which DP functionalities will be added by this wrapper.
    noise_multiplier : float
        The noise multiplier for the Gaussian mechanism for model updates.
    num_sampled_clients : int
        The number of clients that are sampled on each round.
    initial_clipping_norm : float
        The initial value of clipping norm. Defaults to 0.1.
        Andrew et al. recommend setting it to 0.1.
    target_clipped_quantile : float
        The desired quantile of updates which should be clipped. Defaults to 0.5.
    clip_norm_lr : float
        The learning rate for the clipping norm adaptation. Defaults to 0.2.
        Andrew et al. recommend setting it to 0.2.
    clipped_count_stddev : Optional[float]
        The stddev of the noise added to the count of
        updates currently below the estimate.
        Andrew et al. recommend setting it to `expected_num_records/20`.

    Examples
    --------
    Create a strategy::

        strategy = fl.serverapp.FedAvg(...)

    Wrap the strategy with the `DifferentialPrivacyClientSideAdaptiveClipping` wrapper::

        dp_strategy = DifferentialPrivacyClientSideAdaptiveClipping(
            strategy, cfg.noise_multiplier, cfg.num_sampled_clients, ...
        )

    On the client, add the `adaptiveclipping_mod` to the client-side mods::

        app = fl.client.ClientApp(mods=[adaptiveclipping_mod])
    """

    def __repr__(self) -> str:
        """Compute a string representation of the strategy."""
        return "Differential Privacy Strategy Wrapper (Client-Side Adaptive Clipping)"

    def summary(self) -> None:
        """Log summary configuration of the strategy."""
        log(INFO, "\t├──> DP settings:")
        log(INFO, "\t│\t├── Noise multiplier: %s", self.noise_multiplier)
        log(INFO, "\t│\t├── Clipping norm: %s", self.clipping_norm)
        log(INFO, "\t│\t├── Target clipped quantile: %s", self.target_clipped_quantile)
        log(INFO, "\t│\t└── Clip norm learning rate: %s", self.clip_norm_lr)
        super().summary()

    def configure_train(
        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
    ) -> Iterable[Message]:
        """Configure the next round of training."""
        # Ship the current clipping norm to the clients through the config
        config[KEY_CLIPPING_NORM] = self.clipping_norm
        return self.strategy.configure_train(server_round, arrays, config, grid)

    def aggregate_train(
        self, server_round: int, replies: Iterable[Message]
    ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
        """Aggregate ArrayRecords and MetricRecords in the received Messages.

        Returns `(None, None)` when `validate_replies` rejects the replies or
        when there are no replies at all. Otherwise returns the wrapped
        strategy's aggregate (with Gaussian noise added when it is non-empty)
        and its aggregated metrics.

        Raises
        ------
        AggregationError
            If a metric record of any reply does not contain `KEY_NORM_BIT`.
        """
        # Materialize before validating: `replies` may be a one-shot iterator
        # and `validate_replies` iterates over it, which would otherwise leave
        # nothing to aggregate afterwards.
        replies_list = list(replies)
        if not validate_replies(replies_list, self.num_sampled_clients):
            return None, None
        if not replies_list:
            # Nothing to aggregate; also keeps `_noisy_fraction` from dividing
            # by zero.
            return None, None

        # validate that KEY_NORM_BIT is present in all replies
        for msg in replies_list:
            for _, mrec in msg.content.metric_records.items():
                if KEY_NORM_BIT not in mrec:
                    raise AggregationError(
                        f"KEY_NORM_BIT ('{KEY_NORM_BIT}') not found"
                        f" in MetricRecord or metrics for reply."
                    )

        aggregated_arrays, aggregated_metrics = self.strategy.aggregate_train(
            server_round, replies_list
        )

        # Add the noise while `self.clipping_norm` still holds the value that
        # was sent to the clients this round, so the noise scale matches it.
        if aggregated_arrays:
            aggregated_arrays = self._add_noise_to_aggregated_arrays(aggregated_arrays)

        # Only now adapt the clipping norm; the new value applies to the next
        # round.
        self._update_clip_norm_from_replies(replies_list)

        return aggregated_arrays, aggregated_metrics

    def _update_clip_norm_from_replies(self, replies: list[Message]) -> None:
        """Adapt the clipping norm from the `KEY_NORM_BIT` flags in `replies`.

        Counts the replies whose flag is truthy, turns the count into a noisy
        fraction of `len(replies)` and applies the geometric update. `replies`
        must not be empty.
        """
        total = len(replies)
        clipped_count = 0

        for msg in replies:
            # KEY_NORM_BIT is guaranteed to be present
            for _, mrec in msg.content.metric_records.items():
                if KEY_NORM_BIT in mrec:
                    clipped_count += int(bool(mrec[KEY_NORM_BIT]))
                    # One flag per reply: stop at the first record that has it
                    break
            else:
                # Check fallback location (reached only when no metric record
                # of this reply carried the key)
                if hasattr(msg.content, "metrics") and isinstance(
                    msg.content.metrics, dict
                ):
                    clipped_count += int(bool(msg.content.metrics[KEY_NORM_BIT]))

        clipped_fraction = self._noisy_fraction(clipped_count, total)
        self._geometric_update(clipped_fraction)
|
@@ -84,53 +84,6 @@ class DifferentialPrivacyFixedClippingBase(Strategy, ABC):
|
|
84
84
|
self.clipping_norm = clipping_norm
|
85
85
|
self.num_sampled_clients = num_sampled_clients
|
86
86
|
|
87
|
-
def _validate_replies(self, replies: Iterable[Message]) -> bool:
|
88
|
-
"""Validate replies and log errors/warnings.
|
89
|
-
|
90
|
-
Returns
|
91
|
-
-------
|
92
|
-
bool
|
93
|
-
True if replies are valid for aggregation, False otherwise.
|
94
|
-
"""
|
95
|
-
num_errors = 0
|
96
|
-
num_replies_with_content = 0
|
97
|
-
for msg in replies:
|
98
|
-
if msg.has_error():
|
99
|
-
log(
|
100
|
-
INFO,
|
101
|
-
"Received error in reply from node %d: %s",
|
102
|
-
msg.metadata.src_node_id,
|
103
|
-
msg.error,
|
104
|
-
)
|
105
|
-
num_errors += 1
|
106
|
-
else:
|
107
|
-
num_replies_with_content += 1
|
108
|
-
|
109
|
-
# Errors are not allowed
|
110
|
-
if num_errors:
|
111
|
-
log(
|
112
|
-
INFO,
|
113
|
-
"aggregate_train: Some clients reported errors. Skipping aggregation.",
|
114
|
-
)
|
115
|
-
return False
|
116
|
-
|
117
|
-
log(
|
118
|
-
INFO,
|
119
|
-
"aggregate_train: Received %s results and %s failures",
|
120
|
-
num_replies_with_content,
|
121
|
-
num_errors,
|
122
|
-
)
|
123
|
-
|
124
|
-
if num_replies_with_content != self.num_sampled_clients:
|
125
|
-
log(
|
126
|
-
WARNING,
|
127
|
-
CLIENTS_DISCREPANCY_WARNING,
|
128
|
-
num_replies_with_content,
|
129
|
-
self.num_sampled_clients,
|
130
|
-
)
|
131
|
-
|
132
|
-
return True
|
133
|
-
|
134
87
|
def _add_noise_to_aggregated_arrays(
|
135
88
|
self, aggregated_arrays: ArrayRecord
|
136
89
|
) -> ArrayRecord:
|
@@ -228,6 +181,13 @@ class DifferentialPrivacyServerSideFixedClipping(DifferentialPrivacyFixedClippin
|
|
228
181
|
"""Compute a string representation of the strategy."""
|
229
182
|
return "Differential Privacy Strategy Wrapper (Server-Side Fixed Clipping)"
|
230
183
|
|
184
|
+
def summary(self) -> None:
    """Log the DP settings, followed by the parent class summary."""
    log(INFO, "\t├──> DP settings:")
    # (log format, attribute name) pairs, emitted in this order
    dp_lines = (
        ("\t│\t├── Noise multiplier: %s", "noise_multiplier"),
        ("\t│\t└── Clipping norm: %s", "clipping_norm"),
    )
    for line_format, attr_name in dp_lines:
        log(INFO, line_format, getattr(self, attr_name))
    super().summary()
|
190
|
+
|
231
191
|
def configure_train(
|
232
192
|
self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
|
233
193
|
) -> Iterable[Message]:
|
@@ -241,7 +201,7 @@ class DifferentialPrivacyServerSideFixedClipping(DifferentialPrivacyFixedClippin
|
|
241
201
|
replies: Iterable[Message],
|
242
202
|
) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
|
243
203
|
"""Aggregate ArrayRecords and MetricRecords in the received Messages."""
|
244
|
-
if not self.
|
204
|
+
if not validate_replies(replies, self.num_sampled_clients):
|
245
205
|
return None, None
|
246
206
|
|
247
207
|
# Clip arrays in replies
|
@@ -322,6 +282,13 @@ class DifferentialPrivacyClientSideFixedClipping(DifferentialPrivacyFixedClippin
|
|
322
282
|
"""Compute a string representation of the strategy."""
|
323
283
|
return "Differential Privacy Strategy Wrapper (Client-Side Fixed Clipping)"
|
324
284
|
|
285
|
+
def summary(self) -> None:
    """Log the DP settings, followed by the parent class summary."""
    log(INFO, "\t├──> DP settings:")
    # (log format, attribute name) pairs, emitted in this order
    dp_lines = (
        ("\t│\t├── Noise multiplier: %s", "noise_multiplier"),
        ("\t│\t└── Clipping norm: %s", "clipping_norm"),
    )
    for line_format, attr_name in dp_lines:
        log(INFO, line_format, getattr(self, attr_name))
    super().summary()
|
291
|
+
|
325
292
|
def configure_train(
|
326
293
|
self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
|
327
294
|
) -> Iterable[Message]:
|
@@ -337,7 +304,7 @@ class DifferentialPrivacyClientSideFixedClipping(DifferentialPrivacyFixedClippin
|
|
337
304
|
replies: Iterable[Message],
|
338
305
|
) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
|
339
306
|
"""Aggregate ArrayRecords and MetricRecords in the received Messages."""
|
340
|
-
if not self.
|
307
|
+
if not validate_replies(replies, self.num_sampled_clients):
|
341
308
|
return None, None
|
342
309
|
|
343
310
|
# Aggregate
|
@@ -350,3 +317,58 @@ class DifferentialPrivacyClientSideFixedClipping(DifferentialPrivacyFixedClippin
|
|
350
317
|
aggregated_arrays = self._add_noise_to_aggregated_arrays(aggregated_arrays)
|
351
318
|
|
352
319
|
return aggregated_arrays, aggregated_metrics
|
320
|
+
|
321
|
+
|
322
|
+
def validate_replies(replies: Iterable[Message], num_sampled_clients: int) -> bool:
    """Check replies for errors and log what was received.

    `replies` is iterated exactly once, so a one-shot iterator passed here is
    consumed by this call.

    Parameters
    ----------
    replies : Iterable[Message]
        The replies to validate.
    num_sampled_clients : int
        The expected number of sampled clients.

    Returns
    -------
    bool
        True if replies are valid for aggregation, False otherwise.
    """
    failures = 0
    successes = 0
    for reply in replies:
        if not reply.has_error():
            successes += 1
            continue
        log(
            INFO,
            "Received error in reply from node %d: %s",
            reply.metadata.src_node_id,
            reply.error,
        )
        failures += 1

    # A single errored reply invalidates the whole set
    if failures > 0:
        log(
            INFO,
            "aggregate_train: Some clients reported errors. Skipping aggregation.",
        )
        return False

    log(
        INFO,
        "aggregate_train: Received %s results and %s failures",
        successes,
        failures,
    )

    # A count mismatch is only warned about; the replies remain valid
    if successes != num_sampled_clients:
        log(
            WARNING,
            CLIENTS_DISCREPANCY_WARNING,
            successes,
            num_sampled_clients,
        )

    return True
|
@@ -153,9 +153,6 @@ class FedAdagrad(FedOpt):
|
|
153
153
|
for k, x in self.current_arrays.items()
|
154
154
|
}
|
155
155
|
|
156
|
-
# Update current arrays
|
157
|
-
self.current_arrays = new_arrays
|
158
|
-
|
159
156
|
return (
|
160
157
|
ArrayRecord(OrderedDict({k: Array(v) for k, v in new_arrays.items()})),
|
161
158
|
aggregated_metrics,
|
@@ -172,9 +172,6 @@ class FedAdam(FedOpt):
|
|
172
172
|
for k, x in self.current_arrays.items()
|
173
173
|
}
|
174
174
|
|
175
|
-
# Update current arrays
|
176
|
-
self.current_arrays = new_arrays
|
177
|
-
|
178
175
|
return (
|
179
176
|
ArrayRecord(OrderedDict({k: Array(v) for k, v in new_arrays.items()})),
|
180
177
|
aggregated_metrics,
|
@@ -126,9 +126,9 @@ class FedAvgM(FedAvg):
|
|
126
126
|
"""Log summary configuration of the strategy."""
|
127
127
|
opt_status = "ON" if self.server_opt else "OFF"
|
128
128
|
log(INFO, "\t├──> FedAvgM settings:")
|
129
|
-
log(INFO, "\t
|
130
|
-
log(INFO, "\t
|
131
|
-
log(INFO, "\t
|
129
|
+
log(INFO, "\t│\t├── Server optimization: %s", opt_status)
|
130
|
+
log(INFO, "\t│\t├── Server learning rate: %s", self.server_learning_rate)
|
131
|
+
log(INFO, "\t│\t└── Server Momentum: %s", self.server_momentum)
|
132
132
|
super().summary()
|
133
133
|
|
134
134
|
def configure_train(
|
@@ -162,7 +162,7 @@ class FedProx(FedAvg):
|
|
162
162
|
def summary(self) -> None:
|
163
163
|
"""Log summary configuration of the strategy."""
|
164
164
|
log(INFO, "\t├──> FedProx settings:")
|
165
|
-
log(INFO, "\t
|
165
|
+
log(INFO, "\t│\t└── Proximal mu: %s", self.proximal_mu)
|
166
166
|
super().summary()
|
167
167
|
|
168
168
|
def configure_train(
|
@@ -108,7 +108,7 @@ class FedTrimmedAvg(FedAvg):
|
|
108
108
|
def summary(self) -> None:
|
109
109
|
"""Log summary configuration of the strategy."""
|
110
110
|
log(INFO, "\t├──> FedTrimmedAvg settings:")
|
111
|
-
log(INFO, "\t
|
111
|
+
log(INFO, "\t│\t└── beta: %s", self.beta)
|
112
112
|
super().summary()
|
113
113
|
|
114
114
|
def aggregate_train(
|