flwr-nightly 1.8.0.dev20240309__py3-none-any.whl → 1.8.0.dev20240311__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/flower_toml.py +4 -48
- flwr/cli/new/new.py +6 -3
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -3
- flwr/cli/new/templates/app/pyproject.toml.tpl +1 -1
- flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +2 -2
- flwr/cli/utils.py +14 -1
- flwr/client/app.py +39 -5
- flwr/client/client_app.py +1 -47
- flwr/client/mod/__init__.py +2 -1
- flwr/client/mod/secure_aggregation/__init__.py +2 -0
- flwr/client/mod/secure_aggregation/secagg_mod.py +30 -0
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +73 -57
- flwr/common/grpc.py +3 -3
- flwr/common/logger.py +78 -15
- flwr/common/object_ref.py +140 -0
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +5 -5
- flwr/common/secure_aggregation/secaggplus_constants.py +7 -6
- flwr/common/secure_aggregation/secaggplus_utils.py +15 -15
- flwr/server/compat/app.py +2 -1
- flwr/server/driver/grpc_driver.py +4 -4
- flwr/server/history.py +22 -15
- flwr/server/run_serverapp.py +22 -4
- flwr/server/server.py +27 -23
- flwr/server/server_app.py +1 -47
- flwr/server/server_config.py +9 -0
- flwr/server/strategy/fedavg.py +2 -0
- flwr/server/superlink/fleet/vce/vce_api.py +9 -2
- flwr/server/superlink/state/in_memory_state.py +34 -32
- flwr/server/workflow/__init__.py +3 -0
- flwr/server/workflow/constant.py +32 -0
- flwr/server/workflow/default_workflows.py +52 -57
- flwr/server/workflow/secure_aggregation/__init__.py +24 -0
- flwr/server/workflow/secure_aggregation/secagg_workflow.py +112 -0
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +676 -0
- {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/METADATA +1 -1
- {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/RECORD +39 -33
- {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/entry_points.txt +0 -0
flwr/server/workflow/secure_aggregation/secaggplus_workflow.py (new file)
@@ -0,0 +1,676 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Workflow for the SecAgg+ protocol."""


import random
from dataclasses import dataclass, field
from logging import DEBUG, ERROR, INFO, WARN
from typing import Dict, List, Optional, Set, Union, cast

import flwr.common.recordset_compat as compat
from flwr.common import (
    Code,
    ConfigsRecord,
    Context,
    FitRes,
    Message,
    MessageType,
    NDArrays,
    RecordSet,
    Status,
    bytes_to_ndarray,
    log,
    ndarrays_to_parameters,
)
from flwr.common.secure_aggregation.crypto.shamir import combine_shares
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
    bytes_to_private_key,
    bytes_to_public_key,
    generate_shared_key,
)
from flwr.common.secure_aggregation.ndarrays_arithmetic import (
    factor_extract,
    get_parameters_shape,
    parameters_addition,
    parameters_mod,
    parameters_subtraction,
)
from flwr.common.secure_aggregation.quantization import dequantize
from flwr.common.secure_aggregation.secaggplus_constants import (
    RECORD_KEY_CONFIGS,
    Key,
    Stage,
)
from flwr.common.secure_aggregation.secaggplus_utils import pseudo_rand_gen
from flwr.server.compat.driver_client_proxy import DriverClientProxy
from flwr.server.compat.legacy_context import LegacyContext
from flwr.server.driver import Driver

from ..constant import MAIN_CONFIGS_RECORD, MAIN_PARAMS_RECORD
from ..constant import Key as WorkflowKey


@dataclass
class WorkflowState:  # pylint: disable=R0902
    """The state of the SecAgg+ protocol."""

    nid_to_fitins: Dict[int, RecordSet] = field(default_factory=dict)
    sampled_node_ids: Set[int] = field(default_factory=set)
    active_node_ids: Set[int] = field(default_factory=set)
    num_shares: int = 0
    threshold: int = 0
    clipping_range: float = 0.0
    quantization_range: int = 0
    mod_range: int = 0
    max_weight: float = 0.0
    nid_to_neighbours: Dict[int, Set[int]] = field(default_factory=dict)
    nid_to_publickeys: Dict[int, List[bytes]] = field(default_factory=dict)
    forward_srcs: Dict[int, List[int]] = field(default_factory=dict)
    forward_ciphertexts: Dict[int, List[bytes]] = field(default_factory=dict)
    aggregate_ndarrays: NDArrays = field(default_factory=list)


class SecAggPlusWorkflow:
    """The workflow for the SecAgg+ protocol.

    The SecAgg+ protocol ensures the secure summation of integer vectors owned by
    multiple parties, without accessing any individual integer vector. This workflow
    allows the server to compute the weighted average of model parameters across all
    clients, ensuring individual contributions remain private. This is achieved by
    clients sending both, a weighting factor and a weighted version of the locally
    updated parameters, both of which are masked for privacy. Specifically, each
    client uploads "[w, w * params]" with masks, where weighting factor 'w' is the
    number of examples ('num_examples') and 'params' represents the model parameters
    ('parameters') from the client's `FitRes`. The server then aggregates these
    contributions to compute the weighted average of model parameters.

    The protocol involves four main stages:
    - 'setup': Send SecAgg+ configuration to clients and collect their public keys.
    - 'share keys': Broadcast public keys among clients and collect encrypted secret
      key shares.
    - 'collect masked vectors': Forward encrypted secret key shares to target clients
      and collect masked model parameters.
    - 'unmask': Collect secret key shares to decrypt and aggregate the model parameters.

    Only the aggregated model parameters are exposed and passed to
    `Strategy.aggregate_fit`, ensuring individual data privacy.

    Parameters
    ----------
    num_shares : Union[int, float]
        The number of shares into which each client's private key is split under
        the SecAgg+ protocol. If specified as a float, it represents the proportion
        of all selected clients, and the number of shares will be set dynamically in
        the run time. A private key can be reconstructed from these shares, allowing
        for the secure aggregation of model updates. Each client sends one share to
        each of its neighbors while retaining one.
    reconstruction_threshold : Union[int, float]
        The minimum number of shares required to reconstruct a client's private key,
        or, if specified as a float, it represents the proportion of the total number
        of shares needed for reconstruction. This threshold ensures privacy by allowing
        for the recovery of contributions from dropped clients during aggregation,
        without compromising individual client data.
    max_weight : Optional[float] (default: 1000.0)
        The maximum value of the weight that can be assigned to any single client's
        update during the weighted average calculation on the server side, e.g., in the
        FedAvg algorithm.
    clipping_range : float, optional (default: 8.0)
        The range within which model parameters are clipped before quantization.
        This parameter ensures each model parameter is bounded within
        [-clipping_range, clipping_range], facilitating quantization.
    quantization_range : int, optional (default: 4194304, this equals 2**22)
        The size of the range into which floating-point model parameters are quantized,
        mapping each parameter to an integer in [0, quantization_range-1]. This
        facilitates cryptographic operations on the model updates.
    modulus_range : int, optional (default: 4294967296, this equals 2**32)
        The range of values from which random mask entries are uniformly sampled
        ([0, modulus_range-1]). `modulus_range` must be less than 4294967296.
        Please use 2**n values for `modulus_range` to prevent overflow issues.
    timeout : Optional[float] (default: None)
        The timeout duration in seconds. If specified, the workflow will wait for
        replies for this duration each time. If `None`, there is no time limit and
        the workflow will wait until replies for all messages are received.

    Notes
    -----
    - Generally, higher `num_shares` means more robust to dropouts while increasing the
      computational costs; higher `reconstruction_threshold` means better privacy
      guarantees but less tolerance to dropouts.
    - Too large `max_weight` may compromise the precision of the quantization.
    - `modulus_range` must be 2**n and larger than `quantization_range`.
    - When `num_shares` is a float, it is interpreted as the proportion of all selected
      clients, and hence the number of shares will be determined in the runtime. This
      allows for dynamic adjustment based on the total number of participating clients.
    - Similarly, when `reconstruction_threshold` is a float, it is interpreted as the
      proportion of the number of shares needed for the reconstruction of a private key.
      This feature enables flexibility in setting the security threshold relative to the
      number of distributed shares.
    - `num_shares`, `reconstruction_threshold`, and the quantization parameters
      (`clipping_range`, `quantization_range`, `modulus_range`) play critical roles in
      balancing privacy, robustness, and efficiency within the SecAgg+ protocol.
    """

    def __init__(  # pylint: disable=R0913
        self,
        num_shares: Union[int, float],
        reconstruction_threshold: Union[int, float],
        *,
        max_weight: float = 1000.0,
        clipping_range: float = 8.0,
        quantization_range: int = 4194304,
        modulus_range: int = 4294967296,
        timeout: Optional[float] = None,
    ) -> None:
        self.num_shares = num_shares
        self.reconstruction_threshold = reconstruction_threshold
        self.max_weight = max_weight
        self.clipping_range = clipping_range
        self.quantization_range = quantization_range
        self.modulus_range = modulus_range
        self.timeout = timeout

        self._check_init_params()

    def __call__(self, driver: Driver, context: Context) -> None:
        """Run the SecAgg+ protocol."""
        if not isinstance(context, LegacyContext):
            raise TypeError(
                f"Expect a LegacyContext, but get {type(context).__name__}."
            )
        state = WorkflowState()

        steps = (
            self.setup_stage,
            self.share_keys_stage,
            self.collect_masked_vectors_stage,
            self.unmask_stage,
        )
        log(INFO, "Secure aggregation commencing.")
        for step in steps:
            if not step(driver, context, state):
                log(INFO, "Secure aggregation halted.")
                return
        log(INFO, "Secure aggregation completed.")

    def _check_init_params(self) -> None:  # pylint: disable=R0912
        # Check `num_shares`
        if not isinstance(self.num_shares, (int, float)):
            raise TypeError("`num_shares` must be of type int or float.")
        if isinstance(self.num_shares, int):
            if self.num_shares == 1:
                self.num_shares = 1.0
            elif self.num_shares <= 2:
                raise ValueError("`num_shares` as an integer must be greater than 2.")
            elif self.num_shares > self.modulus_range / self.quantization_range:
                log(
                    WARN,
                    "A `num_shares` larger than `modulus_range / quantization_range` "
                    "will potentially cause overflow when computing the aggregated "
                    "model parameters.",
                )
        elif self.num_shares <= 0:
            raise ValueError("`num_shares` as a float must be greater than 0.")

        # Check `reconstruction_threshold`
        if not isinstance(self.reconstruction_threshold, (int, float)):
            raise TypeError("`reconstruction_threshold` must be of type int or float.")
        if isinstance(self.reconstruction_threshold, int):
            if self.reconstruction_threshold == 1:
                self.reconstruction_threshold = 1.0
            elif isinstance(self.num_shares, int):
                if self.reconstruction_threshold >= self.num_shares:
                    raise ValueError(
                        "`reconstruction_threshold` must be less than `num_shares`."
                    )
        else:
            if not 0 < self.reconstruction_threshold <= 1:
                raise ValueError(
                    "If `reconstruction_threshold` is a float, "
                    "it must be greater than 0 and less than or equal to 1."
                )

        # Check `max_weight`
        if self.max_weight <= 0:
            raise ValueError("`max_weight` must be greater than 0.")

        # Check `quantization_range`
        if self.quantization_range <= 0:
            raise ValueError("`quantization_range` must be greater than 0.")

        # Check `quantization_range`
        if not isinstance(self.quantization_range, int) or self.quantization_range <= 0:
            raise ValueError(
                "`quantization_range` must be an integer and greater than 0."
            )

        # Check `modulus_range`
        if (
            not isinstance(self.modulus_range, int)
            or self.modulus_range <= self.quantization_range
        ):
            raise ValueError(
                "`modulus_range` must be an integer and "
                "greater than `quantization_range`."
            )
        if bin(self.modulus_range).count("1") != 1:
            raise ValueError("`modulus_range` must be a power of 2.")

    def _check_threshold(self, state: WorkflowState) -> bool:
        for node_id in state.sampled_node_ids:
            active_neighbors = state.nid_to_neighbours[node_id] & state.active_node_ids
            if len(active_neighbors) < state.threshold:
                log(ERROR, "Insufficient available nodes.")
                return False
        return True

    def setup_stage(  # pylint: disable=R0912, R0914, R0915
        self, driver: Driver, context: LegacyContext, state: WorkflowState
    ) -> bool:
        """Execute the 'setup' stage."""
        # Obtain fit instructions
        cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
        current_round = cast(int, cfg[WorkflowKey.CURRENT_ROUND])
        parameters = compat.parametersrecord_to_parameters(
            context.state.parameters_records[MAIN_PARAMS_RECORD],
            keep_input=True,
        )
        proxy_fitins_lst = context.strategy.configure_fit(
            current_round, parameters, context.client_manager
        )
        if not proxy_fitins_lst:
            log(INFO, "configure_fit: no clients selected, cancel")
            return False
        log(
            INFO,
            "configure_fit: strategy sampled %s clients (out of %s)",
            len(proxy_fitins_lst),
            context.client_manager.num_available(),
        )

        state.nid_to_fitins = {
            proxy.node_id: compat.fitins_to_recordset(fitins, False)
            for proxy, fitins in proxy_fitins_lst
        }

        # Protocol config
        sampled_node_ids = list(state.nid_to_fitins.keys())
        num_samples = len(sampled_node_ids)
        if num_samples < 2:
            log(ERROR, "The number of samples should be greater than 1.")
            return False
        if isinstance(self.num_shares, float):
            state.num_shares = round(self.num_shares * num_samples)
            # If even
            if state.num_shares < num_samples and state.num_shares & 1 == 0:
                state.num_shares += 1
            # If too small
            if state.num_shares <= 2:
                state.num_shares = num_samples
        else:
            state.num_shares = self.num_shares
        if isinstance(self.reconstruction_threshold, float):
            state.threshold = round(self.reconstruction_threshold * state.num_shares)
            # Avoid too small threshold
            state.threshold = max(state.threshold, 2)
        else:
            state.threshold = self.reconstruction_threshold
        state.active_node_ids = set(sampled_node_ids)
        state.clipping_range = self.clipping_range
        state.quantization_range = self.quantization_range
        state.mod_range = self.modulus_range
        state.max_weight = self.max_weight
        sa_params_dict = {
            Key.STAGE: Stage.SETUP,
            Key.SAMPLE_NUMBER: num_samples,
            Key.SHARE_NUMBER: state.num_shares,
            Key.THRESHOLD: state.threshold,
            Key.CLIPPING_RANGE: state.clipping_range,
            Key.TARGET_RANGE: state.quantization_range,
            Key.MOD_RANGE: state.mod_range,
            Key.MAX_WEIGHT: state.max_weight,
        }

        # The number of shares should better be odd in the SecAgg+ protocol.
        if num_samples != state.num_shares and state.num_shares & 1 == 0:
            log(WARN, "Number of shares in the SecAgg+ protocol should be odd.")
            state.num_shares += 1

        # Shuffle node IDs
        random.shuffle(sampled_node_ids)
        # Build neighbour relations (node ID -> secure IDs of neighbours)
        half_share = state.num_shares >> 1
        state.nid_to_neighbours = {
            nid: {
                sampled_node_ids[(idx + offset) % num_samples]
                for offset in range(-half_share, half_share + 1)
            }
            for idx, nid in enumerate(sampled_node_ids)
        }

        state.sampled_node_ids = state.active_node_ids

        # Send setup configuration to clients
        cfgs_record = ConfigsRecord(sa_params_dict)  # type: ignore
        content = RecordSet(configs_records={RECORD_KEY_CONFIGS: cfgs_record})

        def make(nid: int) -> Message:
            return driver.create_message(
                content=content,
                message_type=MessageType.TRAIN,
                dst_node_id=nid,
                group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
                ttl="",
            )

        log(
            DEBUG,
            "[Stage 0] Sending configurations to %s clients.",
            len(state.active_node_ids),
        )
        msgs = driver.send_and_receive(
            [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
        )
        state.active_node_ids = {
            msg.metadata.src_node_id for msg in msgs if not msg.has_error()
        }
        log(
            DEBUG,
            "[Stage 0] Received public keys from %s clients.",
            len(state.active_node_ids),
        )

        for msg in msgs:
            if msg.has_error():
                continue
            key_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
            node_id = msg.metadata.src_node_id
            pk1, pk2 = key_dict[Key.PUBLIC_KEY_1], key_dict[Key.PUBLIC_KEY_2]
            state.nid_to_publickeys[node_id] = [cast(bytes, pk1), cast(bytes, pk2)]

        return self._check_threshold(state)

    def share_keys_stage(  # pylint: disable=R0914
        self, driver: Driver, context: LegacyContext, state: WorkflowState
    ) -> bool:
        """Execute the 'share keys' stage."""
        cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]

        def make(nid: int) -> Message:
            neighbours = state.nid_to_neighbours[nid] & state.active_node_ids
            cfgs_record = ConfigsRecord(
                {str(nid): state.nid_to_publickeys[nid] for nid in neighbours}
            )
            cfgs_record[Key.STAGE] = Stage.SHARE_KEYS
            content = RecordSet(configs_records={RECORD_KEY_CONFIGS: cfgs_record})
            return driver.create_message(
                content=content,
                message_type=MessageType.TRAIN,
                dst_node_id=nid,
                group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
                ttl="",
            )

        # Broadcast public keys to clients and receive secret key shares
        log(
            DEBUG,
            "[Stage 1] Forwarding public keys to %s clients.",
            len(state.active_node_ids),
        )
        msgs = driver.send_and_receive(
            [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
        )
        state.active_node_ids = {
            msg.metadata.src_node_id for msg in msgs if not msg.has_error()
        }
        log(
            DEBUG,
            "[Stage 1] Received encrypted key shares from %s clients.",
            len(state.active_node_ids),
        )

        # Build forward packet list dictionary
        srcs: List[int] = []
        dsts: List[int] = []
        ciphertexts: List[bytes] = []
        fwd_ciphertexts: Dict[int, List[bytes]] = {
            nid: [] for nid in state.active_node_ids
        }  # dest node ID -> list of ciphertexts
        fwd_srcs: Dict[int, List[int]] = {
            nid: [] for nid in state.active_node_ids
        }  # dest node ID -> list of src node IDs
        for msg in msgs:
            node_id = msg.metadata.src_node_id
            res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
            dst_lst = cast(List[int], res_dict[Key.DESTINATION_LIST])
            ctxt_lst = cast(List[bytes], res_dict[Key.CIPHERTEXT_LIST])
            srcs += [node_id] * len(dst_lst)
            dsts += dst_lst
            ciphertexts += ctxt_lst

        for src, dst, ciphertext in zip(srcs, dsts, ciphertexts):
            if dst in fwd_ciphertexts:
                fwd_ciphertexts[dst].append(ciphertext)
                fwd_srcs[dst].append(src)

        state.forward_srcs = fwd_srcs
        state.forward_ciphertexts = fwd_ciphertexts

        return self._check_threshold(state)

    def collect_masked_vectors_stage(
        self, driver: Driver, context: LegacyContext, state: WorkflowState
    ) -> bool:
        """Execute the 'collect masked vectors' stage."""
        cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]

        # Send secret key shares to clients (plus FitIns) and collect masked vectors
        def make(nid: int) -> Message:
            cfgs_dict = {
                Key.STAGE: Stage.COLLECT_MASKED_VECTORS,
                Key.CIPHERTEXT_LIST: state.forward_ciphertexts[nid],
                Key.SOURCE_LIST: state.forward_srcs[nid],
            }
            cfgs_record = ConfigsRecord(cfgs_dict)  # type: ignore
            content = state.nid_to_fitins[nid]
            content.configs_records[RECORD_KEY_CONFIGS] = cfgs_record
            return driver.create_message(
                content=content,
                message_type=MessageType.TRAIN,
                dst_node_id=nid,
                group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
                ttl="",
            )

        log(
            DEBUG,
            "[Stage 2] Forwarding encrypted key shares to %s clients.",
            len(state.active_node_ids),
        )
        msgs = driver.send_and_receive(
            [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
        )
        state.active_node_ids = {
            msg.metadata.src_node_id for msg in msgs if not msg.has_error()
        }
        log(
            DEBUG,
            "[Stage 2] Received masked vectors from %s clients.",
            len(state.active_node_ids),
        )

        # Clear cache
        del state.forward_ciphertexts, state.forward_srcs, state.nid_to_fitins

        # Sum collected masked vectors and compute active/dead node IDs
        masked_vector = None
        for msg in msgs:
            res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
            bytes_list = cast(List[bytes], res_dict[Key.MASKED_PARAMETERS])
            client_masked_vec = [bytes_to_ndarray(b) for b in bytes_list]
            if masked_vector is None:
                masked_vector = client_masked_vec
            else:
                masked_vector = parameters_addition(masked_vector, client_masked_vec)
        if masked_vector is not None:
            masked_vector = parameters_mod(masked_vector, state.mod_range)
            state.aggregate_ndarrays = masked_vector

        return self._check_threshold(state)

    def unmask_stage(  # pylint: disable=R0912, R0914, R0915
        self, driver: Driver, context: LegacyContext, state: WorkflowState
    ) -> bool:
        """Execute the 'unmask' stage."""
        cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
        current_round = cast(int, cfg[WorkflowKey.CURRENT_ROUND])

        # Construct active node IDs and dead node IDs
        active_nids = state.active_node_ids
        dead_nids = state.sampled_node_ids - active_nids

        # Send secure IDs of active and dead clients and collect key shares from clients
        def make(nid: int) -> Message:
            neighbours = state.nid_to_neighbours[nid]
            cfgs_dict = {
                Key.STAGE: Stage.UNMASK,
                Key.ACTIVE_NODE_ID_LIST: list(neighbours & active_nids),
                Key.DEAD_NODE_ID_LIST: list(neighbours & dead_nids),
            }
            cfgs_record = ConfigsRecord(cfgs_dict)  # type: ignore
            content = RecordSet(configs_records={RECORD_KEY_CONFIGS: cfgs_record})
            return driver.create_message(
                content=content,
                message_type=MessageType.TRAIN,
                dst_node_id=nid,
                group_id=str(current_round),
                ttl="",
            )

        log(
            DEBUG,
            "[Stage 3] Requesting key shares from %s clients to remove masks.",
            len(state.active_node_ids),
        )
        msgs = driver.send_and_receive(
            [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
        )
        state.active_node_ids = {
            msg.metadata.src_node_id for msg in msgs if not msg.has_error()
        }
        log(
            DEBUG,
            "[Stage 3] Received key shares from %s clients.",
            len(state.active_node_ids),
        )

        # Build collected shares dict
        collected_shares_dict: Dict[int, List[bytes]] = {}
        for nid in state.sampled_node_ids:
            collected_shares_dict[nid] = []
        for msg in msgs:
            res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
            nids = cast(List[int], res_dict[Key.NODE_ID_LIST])
            shares = cast(List[bytes], res_dict[Key.SHARE_LIST])
            for owner_nid, share in zip(nids, shares):
                collected_shares_dict[owner_nid].append(share)

        # Remove masks for every active client after collect_masked_vectors stage
        masked_vector = state.aggregate_ndarrays
        del state.aggregate_ndarrays
        for nid, share_list in collected_shares_dict.items():
            if len(share_list) < state.threshold:
                log(
                    ERROR, "Not enough shares to recover secret in unmask vectors stage"
                )
                return False
            secret = combine_shares(share_list)
            if nid in active_nids:
                # The seed for PRG is the private mask seed of an active client.
                private_mask = pseudo_rand_gen(
                    secret, state.mod_range, get_parameters_shape(masked_vector)
                )
                masked_vector = parameters_subtraction(masked_vector, private_mask)
            else:
                # The seed for PRG is the secret key 1 of a dropped client.
                neighbours = state.nid_to_neighbours[nid]
                neighbours.remove(nid)

                for neighbor_nid in neighbours:
                    shared_key = generate_shared_key(
                        bytes_to_private_key(secret),
                        bytes_to_public_key(state.nid_to_publickeys[neighbor_nid][0]),
                    )
                    pairwise_mask = pseudo_rand_gen(
                        shared_key, state.mod_range, get_parameters_shape(masked_vector)
                    )
                    if nid > neighbor_nid:
                        masked_vector = parameters_addition(
                            masked_vector, pairwise_mask
                        )
                    else:
                        masked_vector = parameters_subtraction(
                            masked_vector, pairwise_mask
                        )
        recon_parameters = parameters_mod(masked_vector, state.mod_range)
        q_total_ratio, recon_parameters = factor_extract(recon_parameters)
        inv_dq_total_ratio = state.quantization_range / q_total_ratio
        # recon_parameters = parameters_divide(recon_parameters, total_weights_factor)
        aggregated_vector = dequantize(
            recon_parameters,
            state.clipping_range,
            state.quantization_range,
        )
        offset = -(len(active_nids) - 1) * state.clipping_range
        for vec in aggregated_vector:
            vec += offset
            vec *= inv_dq_total_ratio
        state.aggregate_ndarrays = aggregated_vector

        # No exception/failure handling currently
        log(
            INFO,
            "aggregate_fit: received %s results and %s failures",
            1,
            0,
        )

        final_fitres = FitRes(
            status=Status(code=Code.OK, message=""),
            parameters=ndarrays_to_parameters(aggregated_vector),
            num_examples=round(state.max_weight / inv_dq_total_ratio),
            metrics={},
        )
        empty_proxy = DriverClientProxy(
            0,
            driver.grpc_driver,  # type: ignore
            False,
            driver.run_id,  # type: ignore
        )
        aggregated_result = context.strategy.aggregate_fit(
            current_round, [(empty_proxy, final_fitres)], []
        )
        parameters_aggregated, metrics_aggregated = aggregated_result

        # Update the parameters and write history
        if parameters_aggregated:
            paramsrecord = compat.parameters_to_parametersrecord(
                parameters_aggregated, True
            )
            context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
            context.history.add_metrics_distributed_fit(
                server_round=current_round, metrics=metrics_aggregated
            )
        return True
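For reference, the constructor and `__call__` signatures in the new file suggest a server-side setup along the following lines. This is a minimal sketch, not code contained in this diff: the `DefaultWorkflow` wrapper, the `ServerApp.main` registration, the top-level exports from `flwr.server.workflow`, and the exact `LegacyContext` constructor arguments are assumptions inferred from the other files touched here (`flwr/server/workflow/__init__.py`, `flwr/server/workflow/default_workflows.py`, `flwr/server/server_app.py`).

# Illustrative sketch only (not part of the diffed package contents).
from flwr.common import Context
from flwr.server import ServerApp, ServerConfig
from flwr.server.compat.legacy_context import LegacyContext  # import path taken from the file above
from flwr.server.driver import Driver
from flwr.server.strategy import FedAvg
from flwr.server.workflow import DefaultWorkflow, SecAggPlusWorkflow  # assumed exports

app = ServerApp()


@app.main()
def main(driver: Driver, context: Context) -> None:
    # SecAggPlusWorkflow.__call__ requires a LegacyContext (see the type check above).
    legacy_context = LegacyContext(
        state=context.state,
        config=ServerConfig(num_rounds=3),
        strategy=FedAvg(),
    )

    # Run regular FedAvg rounds, but aggregate fit results with the SecAgg+ protocol:
    # each client's key material is split into 3 shares, 2 of which suffice for
    # reconstruction if that client drops out.
    workflow = DefaultWorkflow(
        fit_workflow=SecAggPlusWorkflow(
            num_shares=3,
            reconstruction_threshold=2,
            timeout=60.0,
        )
    )
    workflow(driver, legacy_context)

Note that `num_shares` and `reconstruction_threshold` may also be given as floats, in which case the workflow resolves them at runtime against the number of sampled clients, as implemented in `setup_stage` above.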