flwr-nightly 1.8.0.dev20240309__py3-none-any.whl → 1.8.0.dev20240311__py3-none-any.whl
- flwr/cli/flower_toml.py +4 -48
- flwr/cli/new/new.py +6 -3
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -3
- flwr/cli/new/templates/app/pyproject.toml.tpl +1 -1
- flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +2 -2
- flwr/cli/utils.py +14 -1
- flwr/client/app.py +39 -5
- flwr/client/client_app.py +1 -47
- flwr/client/mod/__init__.py +2 -1
- flwr/client/mod/secure_aggregation/__init__.py +2 -0
- flwr/client/mod/secure_aggregation/secagg_mod.py +30 -0
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +73 -57
- flwr/common/grpc.py +3 -3
- flwr/common/logger.py +78 -15
- flwr/common/object_ref.py +140 -0
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +5 -5
- flwr/common/secure_aggregation/secaggplus_constants.py +7 -6
- flwr/common/secure_aggregation/secaggplus_utils.py +15 -15
- flwr/server/compat/app.py +2 -1
- flwr/server/driver/grpc_driver.py +4 -4
- flwr/server/history.py +22 -15
- flwr/server/run_serverapp.py +22 -4
- flwr/server/server.py +27 -23
- flwr/server/server_app.py +1 -47
- flwr/server/server_config.py +9 -0
- flwr/server/strategy/fedavg.py +2 -0
- flwr/server/superlink/fleet/vce/vce_api.py +9 -2
- flwr/server/superlink/state/in_memory_state.py +34 -32
- flwr/server/workflow/__init__.py +3 -0
- flwr/server/workflow/constant.py +32 -0
- flwr/server/workflow/default_workflows.py +52 -57
- flwr/server/workflow/secure_aggregation/__init__.py +24 -0
- flwr/server/workflow/secure_aggregation/secagg_workflow.py +112 -0
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +676 -0
- {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/METADATA +1 -1
- {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/RECORD +39 -33
- {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/entry_points.txt +0 -0
flwr/server/workflow/secure_aggregation/secaggplus_workflow.py (new file)
@@ -0,0 +1,676 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Workflow for the SecAgg+ protocol."""


import random
from dataclasses import dataclass, field
from logging import DEBUG, ERROR, INFO, WARN
from typing import Dict, List, Optional, Set, Union, cast

import flwr.common.recordset_compat as compat
from flwr.common import (
    Code,
    ConfigsRecord,
    Context,
    FitRes,
    Message,
    MessageType,
    NDArrays,
    RecordSet,
    Status,
    bytes_to_ndarray,
    log,
    ndarrays_to_parameters,
)
from flwr.common.secure_aggregation.crypto.shamir import combine_shares
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
    bytes_to_private_key,
    bytes_to_public_key,
    generate_shared_key,
)
from flwr.common.secure_aggregation.ndarrays_arithmetic import (
    factor_extract,
    get_parameters_shape,
    parameters_addition,
    parameters_mod,
    parameters_subtraction,
)
from flwr.common.secure_aggregation.quantization import dequantize
from flwr.common.secure_aggregation.secaggplus_constants import (
    RECORD_KEY_CONFIGS,
    Key,
    Stage,
)
from flwr.common.secure_aggregation.secaggplus_utils import pseudo_rand_gen
from flwr.server.compat.driver_client_proxy import DriverClientProxy
from flwr.server.compat.legacy_context import LegacyContext
from flwr.server.driver import Driver

from ..constant import MAIN_CONFIGS_RECORD, MAIN_PARAMS_RECORD
from ..constant import Key as WorkflowKey


@dataclass
class WorkflowState:  # pylint: disable=R0902
    """The state of the SecAgg+ protocol."""

    nid_to_fitins: Dict[int, RecordSet] = field(default_factory=dict)
    sampled_node_ids: Set[int] = field(default_factory=set)
    active_node_ids: Set[int] = field(default_factory=set)
    num_shares: int = 0
    threshold: int = 0
    clipping_range: float = 0.0
    quantization_range: int = 0
    mod_range: int = 0
    max_weight: float = 0.0
    nid_to_neighbours: Dict[int, Set[int]] = field(default_factory=dict)
    nid_to_publickeys: Dict[int, List[bytes]] = field(default_factory=dict)
    forward_srcs: Dict[int, List[int]] = field(default_factory=dict)
    forward_ciphertexts: Dict[int, List[bytes]] = field(default_factory=dict)
    aggregate_ndarrays: NDArrays = field(default_factory=list)


class SecAggPlusWorkflow:
    """The workflow for the SecAgg+ protocol.

    The SecAgg+ protocol ensures the secure summation of integer vectors owned by
    multiple parties, without accessing any individual integer vector. This workflow
    allows the server to compute the weighted average of model parameters across all
    clients, ensuring individual contributions remain private. This is achieved by
    clients sending both a weighting factor and a weighted version of the locally
    updated parameters, both of which are masked for privacy. Specifically, each
    client uploads "[w, w * params]" with masks, where the weighting factor 'w' is
    the number of examples ('num_examples') and 'params' represents the model
    parameters ('parameters') from the client's `FitRes`. The server then aggregates
    these contributions to compute the weighted average of model parameters.

    The protocol involves four main stages:
    - 'setup': Send SecAgg+ configuration to clients and collect their public keys.
    - 'share keys': Broadcast public keys among clients and collect encrypted secret
      key shares.
    - 'collect masked vectors': Forward encrypted secret key shares to target clients
      and collect masked model parameters.
    - 'unmask': Collect secret key shares to decrypt and aggregate the model
      parameters.

    Only the aggregated model parameters are exposed and passed to
    `Strategy.aggregate_fit`, ensuring individual data privacy.

    Parameters
    ----------
    num_shares : Union[int, float]
        The number of shares into which each client's private key is split under
        the SecAgg+ protocol. If specified as a float, it represents the proportion
        of all selected clients, and the number of shares will be set dynamically at
        run time. A private key can be reconstructed from these shares, allowing
        for the secure aggregation of model updates. Each client sends one share to
        each of its neighbours while retaining one.
    reconstruction_threshold : Union[int, float]
        The minimum number of shares required to reconstruct a client's private key,
        or, if specified as a float, the proportion of the total number of shares
        needed for reconstruction. This threshold ensures privacy by allowing for
        the recovery of contributions from dropped clients during aggregation,
        without compromising individual client data.
    max_weight : Optional[float] (default: 1000.0)
        The maximum value of the weight that can be assigned to any single client's
        update during the weighted average calculation on the server side, e.g., in
        the FedAvg algorithm.
    clipping_range : float, optional (default: 8.0)
        The range within which model parameters are clipped before quantization.
        This parameter ensures each model parameter is bounded within
        [-clipping_range, clipping_range], facilitating quantization.
    quantization_range : int, optional (default: 4194304, this equals 2**22)
        The size of the range into which floating-point model parameters are
        quantized, mapping each parameter to an integer in [0, quantization_range-1].
        This facilitates cryptographic operations on the model updates.
    modulus_range : int, optional (default: 4294967296, this equals 2**32)
        The range of values from which random mask entries are uniformly sampled
        ([0, modulus_range-1]). `modulus_range` must not be larger than 4294967296.
        Please use 2**n values for `modulus_range` to prevent overflow issues.
    timeout : Optional[float] (default: None)
        The timeout duration in seconds. If specified, the workflow will wait for
        replies for this duration each time. If `None`, there is no time limit and
        the workflow will wait until replies for all messages are received.

    Notes
    -----
    - Generally, a higher `num_shares` means more robustness to dropouts at the cost
      of increased computation; a higher `reconstruction_threshold` means better
      privacy guarantees but less tolerance to dropouts.
    - A `max_weight` that is too large may compromise the precision of the
      quantization.
    - `modulus_range` must be 2**n and larger than `quantization_range`.
    - When `num_shares` is a float, it is interpreted as the proportion of all
      selected clients, and hence the number of shares will be determined at run
      time. This allows for dynamic adjustment based on the total number of
      participating clients.
    - Similarly, when `reconstruction_threshold` is a float, it is interpreted as
      the proportion of the number of shares needed for the reconstruction of a
      private key. This feature enables flexibility in setting the security
      threshold relative to the number of distributed shares.
    - `num_shares`, `reconstruction_threshold`, and the quantization parameters
      (`clipping_range`, `quantization_range`, `modulus_range`) play critical roles
      in balancing privacy, robustness, and efficiency within the SecAgg+ protocol.
    """

    def __init__(  # pylint: disable=R0913
        self,
        num_shares: Union[int, float],
        reconstruction_threshold: Union[int, float],
        *,
        max_weight: float = 1000.0,
        clipping_range: float = 8.0,
        quantization_range: int = 4194304,
        modulus_range: int = 4294967296,
        timeout: Optional[float] = None,
    ) -> None:
        self.num_shares = num_shares
        self.reconstruction_threshold = reconstruction_threshold
        self.max_weight = max_weight
        self.clipping_range = clipping_range
        self.quantization_range = quantization_range
        self.modulus_range = modulus_range
        self.timeout = timeout

        self._check_init_params()

    def __call__(self, driver: Driver, context: Context) -> None:
        """Run the SecAgg+ protocol."""
        if not isinstance(context, LegacyContext):
            raise TypeError(
                f"Expected a LegacyContext, but got {type(context).__name__}."
            )
        state = WorkflowState()

        steps = (
            self.setup_stage,
            self.share_keys_stage,
            self.collect_masked_vectors_stage,
            self.unmask_stage,
        )
        log(INFO, "Secure aggregation commencing.")
        for step in steps:
            if not step(driver, context, state):
                log(INFO, "Secure aggregation halted.")
                return
        log(INFO, "Secure aggregation completed.")

    def _check_init_params(self) -> None:  # pylint: disable=R0912
        # Check `num_shares`
        if not isinstance(self.num_shares, (int, float)):
            raise TypeError("`num_shares` must be of type int or float.")
        if isinstance(self.num_shares, int):
            if self.num_shares == 1:
                self.num_shares = 1.0
            elif self.num_shares <= 2:
                raise ValueError("`num_shares` as an integer must be greater than 2.")
            elif self.num_shares > self.modulus_range / self.quantization_range:
                log(
                    WARN,
                    "A `num_shares` larger than `modulus_range / quantization_range` "
                    "will potentially cause overflow when computing the aggregated "
                    "model parameters.",
                )
        elif self.num_shares <= 0:
            raise ValueError("`num_shares` as a float must be greater than 0.")

        # Check `reconstruction_threshold`
        if not isinstance(self.reconstruction_threshold, (int, float)):
            raise TypeError("`reconstruction_threshold` must be of type int or float.")
        if isinstance(self.reconstruction_threshold, int):
            if self.reconstruction_threshold == 1:
                self.reconstruction_threshold = 1.0
            elif isinstance(self.num_shares, int):
                if self.reconstruction_threshold >= self.num_shares:
                    raise ValueError(
                        "`reconstruction_threshold` must be less than `num_shares`."
                    )
        else:
            if not 0 < self.reconstruction_threshold <= 1:
                raise ValueError(
                    "If `reconstruction_threshold` is a float, "
                    "it must be greater than 0 and less than or equal to 1."
                )

        # Check `max_weight`
        if self.max_weight <= 0:
            raise ValueError("`max_weight` must be greater than 0.")

        # Check `quantization_range`
        if self.quantization_range <= 0:
            raise ValueError("`quantization_range` must be greater than 0.")

        # Check `quantization_range`
        if not isinstance(self.quantization_range, int) or self.quantization_range <= 0:
            raise ValueError(
                "`quantization_range` must be an integer and greater than 0."
            )

        # Check `modulus_range`
        if (
            not isinstance(self.modulus_range, int)
            or self.modulus_range <= self.quantization_range
        ):
            raise ValueError(
                "`modulus_range` must be an integer and "
                "greater than `quantization_range`."
            )
        if bin(self.modulus_range).count("1") != 1:
            raise ValueError("`modulus_range` must be a power of 2.")

    def _check_threshold(self, state: WorkflowState) -> bool:
        for node_id in state.sampled_node_ids:
            active_neighbors = state.nid_to_neighbours[node_id] & state.active_node_ids
            if len(active_neighbors) < state.threshold:
                log(ERROR, "Insufficient available nodes.")
                return False
        return True

    def setup_stage(  # pylint: disable=R0912, R0914, R0915
        self, driver: Driver, context: LegacyContext, state: WorkflowState
    ) -> bool:
        """Execute the 'setup' stage."""
        # Obtain fit instructions
        cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
        current_round = cast(int, cfg[WorkflowKey.CURRENT_ROUND])
        parameters = compat.parametersrecord_to_parameters(
            context.state.parameters_records[MAIN_PARAMS_RECORD],
            keep_input=True,
        )
        proxy_fitins_lst = context.strategy.configure_fit(
            current_round, parameters, context.client_manager
        )
        if not proxy_fitins_lst:
            log(INFO, "configure_fit: no clients selected, cancel")
            return False
        log(
            INFO,
            "configure_fit: strategy sampled %s clients (out of %s)",
            len(proxy_fitins_lst),
            context.client_manager.num_available(),
        )

        state.nid_to_fitins = {
            proxy.node_id: compat.fitins_to_recordset(fitins, False)
            for proxy, fitins in proxy_fitins_lst
        }

        # Protocol config
        sampled_node_ids = list(state.nid_to_fitins.keys())
        num_samples = len(sampled_node_ids)
        if num_samples < 2:
            log(ERROR, "The number of samples should be greater than 1.")
            return False
        if isinstance(self.num_shares, float):
            state.num_shares = round(self.num_shares * num_samples)
            # If even
            if state.num_shares < num_samples and state.num_shares & 1 == 0:
                state.num_shares += 1
            # If too small
            if state.num_shares <= 2:
                state.num_shares = num_samples
        else:
            state.num_shares = self.num_shares
        if isinstance(self.reconstruction_threshold, float):
            state.threshold = round(self.reconstruction_threshold * state.num_shares)
            # Avoid too small threshold
            state.threshold = max(state.threshold, 2)
        else:
            state.threshold = self.reconstruction_threshold
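        # Example of the dynamic resolution above (illustrative values): with
        # num_samples=100, num_shares=0.25, and reconstruction_threshold=0.8,
        # state.num_shares = round(0.25 * 100) = 25 (odd, so kept as is) and
        # state.threshold = round(0.8 * 25) = 20.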
        state.active_node_ids = set(sampled_node_ids)
        state.clipping_range = self.clipping_range
        state.quantization_range = self.quantization_range
        state.mod_range = self.modulus_range
        state.max_weight = self.max_weight
        sa_params_dict = {
            Key.STAGE: Stage.SETUP,
            Key.SAMPLE_NUMBER: num_samples,
            Key.SHARE_NUMBER: state.num_shares,
            Key.THRESHOLD: state.threshold,
            Key.CLIPPING_RANGE: state.clipping_range,
            Key.TARGET_RANGE: state.quantization_range,
            Key.MOD_RANGE: state.mod_range,
            Key.MAX_WEIGHT: state.max_weight,
        }

        # The number of shares should better be odd in the SecAgg+ protocol.
        if num_samples != state.num_shares and state.num_shares & 1 == 0:
            log(WARN, "Number of shares in the SecAgg+ protocol should be odd.")
            state.num_shares += 1

        # Shuffle node IDs
        random.shuffle(sampled_node_ids)
        # Build neighbour relations (node ID -> secure IDs of neighbours)
        half_share = state.num_shares >> 1
        state.nid_to_neighbours = {
            nid: {
                sampled_node_ids[(idx + offset) % num_samples]
                for offset in range(-half_share, half_share + 1)
            }
            for idx, nid in enumerate(sampled_node_ids)
        }
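        # On the ring of shuffled node IDs, each node's neighbourhood is itself
        # plus `half_share` nodes on each side; e.g., with state.num_shares=3,
        # the neighbours of the node at index i are the nodes at indices
        # {i - 1, i, i + 1} (mod num_samples).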

        state.sampled_node_ids = state.active_node_ids

        # Send setup configuration to clients
        cfgs_record = ConfigsRecord(sa_params_dict)  # type: ignore
        content = RecordSet(configs_records={RECORD_KEY_CONFIGS: cfgs_record})

        def make(nid: int) -> Message:
            return driver.create_message(
                content=content,
                message_type=MessageType.TRAIN,
                dst_node_id=nid,
                group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
                ttl="",
            )

        log(
            DEBUG,
            "[Stage 0] Sending configurations to %s clients.",
            len(state.active_node_ids),
        )
        msgs = driver.send_and_receive(
            [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
        )
        state.active_node_ids = {
            msg.metadata.src_node_id for msg in msgs if not msg.has_error()
        }
        log(
            DEBUG,
            "[Stage 0] Received public keys from %s clients.",
            len(state.active_node_ids),
        )

        for msg in msgs:
            if msg.has_error():
                continue
            key_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
            node_id = msg.metadata.src_node_id
            pk1, pk2 = key_dict[Key.PUBLIC_KEY_1], key_dict[Key.PUBLIC_KEY_2]
            state.nid_to_publickeys[node_id] = [cast(bytes, pk1), cast(bytes, pk2)]

        return self._check_threshold(state)

    def share_keys_stage(  # pylint: disable=R0914
        self, driver: Driver, context: LegacyContext, state: WorkflowState
    ) -> bool:
        """Execute the 'share keys' stage."""
        cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]

        def make(nid: int) -> Message:
            neighbours = state.nid_to_neighbours[nid] & state.active_node_ids
            cfgs_record = ConfigsRecord(
                {str(nid): state.nid_to_publickeys[nid] for nid in neighbours}
            )
            cfgs_record[Key.STAGE] = Stage.SHARE_KEYS
            content = RecordSet(configs_records={RECORD_KEY_CONFIGS: cfgs_record})
            return driver.create_message(
                content=content,
                message_type=MessageType.TRAIN,
                dst_node_id=nid,
                group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
                ttl="",
            )

        # Broadcast public keys to clients and receive secret key shares
        log(
            DEBUG,
            "[Stage 1] Forwarding public keys to %s clients.",
            len(state.active_node_ids),
        )
        msgs = driver.send_and_receive(
            [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
        )
        state.active_node_ids = {
            msg.metadata.src_node_id for msg in msgs if not msg.has_error()
        }
        log(
            DEBUG,
            "[Stage 1] Received encrypted key shares from %s clients.",
            len(state.active_node_ids),
        )

        # Build forward packet list dictionary
        srcs: List[int] = []
        dsts: List[int] = []
        ciphertexts: List[bytes] = []
        fwd_ciphertexts: Dict[int, List[bytes]] = {
            nid: [] for nid in state.active_node_ids
        }  # dest node ID -> list of ciphertexts
        fwd_srcs: Dict[int, List[int]] = {
            nid: [] for nid in state.active_node_ids
        }  # dest node ID -> list of src node IDs
        for msg in msgs:
            node_id = msg.metadata.src_node_id
            res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
            dst_lst = cast(List[int], res_dict[Key.DESTINATION_LIST])
            ctxt_lst = cast(List[bytes], res_dict[Key.CIPHERTEXT_LIST])
            srcs += [node_id] * len(dst_lst)
            dsts += dst_lst
            ciphertexts += ctxt_lst

        for src, dst, ciphertext in zip(srcs, dsts, ciphertexts):
            if dst in fwd_ciphertexts:
                fwd_ciphertexts[dst].append(ciphertext)
                fwd_srcs[dst].append(src)

        state.forward_srcs = fwd_srcs
        state.forward_ciphertexts = fwd_ciphertexts

        return self._check_threshold(state)
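    # Stage 2 below forwards, to every active node, the encrypted key shares
    # regrouped per destination above, piggybacking them onto the FitIns produced
    # during setup, and then sums the returned masked vectors modulo `mod_range`.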

    def collect_masked_vectors_stage(
        self, driver: Driver, context: LegacyContext, state: WorkflowState
    ) -> bool:
        """Execute the 'collect masked vectors' stage."""
        cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]

        # Send secret key shares to clients (plus FitIns) and collect masked vectors
        def make(nid: int) -> Message:
            cfgs_dict = {
                Key.STAGE: Stage.COLLECT_MASKED_VECTORS,
                Key.CIPHERTEXT_LIST: state.forward_ciphertexts[nid],
                Key.SOURCE_LIST: state.forward_srcs[nid],
            }
            cfgs_record = ConfigsRecord(cfgs_dict)  # type: ignore
            content = state.nid_to_fitins[nid]
            content.configs_records[RECORD_KEY_CONFIGS] = cfgs_record
            return driver.create_message(
                content=content,
                message_type=MessageType.TRAIN,
                dst_node_id=nid,
                group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
                ttl="",
            )

        log(
            DEBUG,
            "[Stage 2] Forwarding encrypted key shares to %s clients.",
            len(state.active_node_ids),
        )
        msgs = driver.send_and_receive(
            [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
        )
        state.active_node_ids = {
            msg.metadata.src_node_id for msg in msgs if not msg.has_error()
        }
        log(
            DEBUG,
            "[Stage 2] Received masked vectors from %s clients.",
            len(state.active_node_ids),
        )

        # Clear cache
        del state.forward_ciphertexts, state.forward_srcs, state.nid_to_fitins

        # Sum collected masked vectors and compute active/dead node IDs
        masked_vector = None
        for msg in msgs:
            res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
            bytes_list = cast(List[bytes], res_dict[Key.MASKED_PARAMETERS])
            client_masked_vec = [bytes_to_ndarray(b) for b in bytes_list]
            if masked_vector is None:
                masked_vector = client_masked_vec
            else:
                masked_vector = parameters_addition(masked_vector, client_masked_vec)
        if masked_vector is not None:
            masked_vector = parameters_mod(masked_vector, state.mod_range)
            state.aggregate_ndarrays = masked_vector

        return self._check_threshold(state)

    def unmask_stage(  # pylint: disable=R0912, R0914, R0915
        self, driver: Driver, context: LegacyContext, state: WorkflowState
    ) -> bool:
        """Execute the 'unmask' stage."""
        cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
        current_round = cast(int, cfg[WorkflowKey.CURRENT_ROUND])

        # Construct active node IDs and dead node IDs
        active_nids = state.active_node_ids
        dead_nids = state.sampled_node_ids - active_nids

        # Send secure IDs of active and dead clients and collect key shares from clients
        def make(nid: int) -> Message:
            neighbours = state.nid_to_neighbours[nid]
            cfgs_dict = {
                Key.STAGE: Stage.UNMASK,
                Key.ACTIVE_NODE_ID_LIST: list(neighbours & active_nids),
                Key.DEAD_NODE_ID_LIST: list(neighbours & dead_nids),
            }
            cfgs_record = ConfigsRecord(cfgs_dict)  # type: ignore
            content = RecordSet(configs_records={RECORD_KEY_CONFIGS: cfgs_record})
            return driver.create_message(
                content=content,
                message_type=MessageType.TRAIN,
                dst_node_id=nid,
                group_id=str(current_round),
                ttl="",
            )

        log(
            DEBUG,
            "[Stage 3] Requesting key shares from %s clients to remove masks.",
            len(state.active_node_ids),
        )
        msgs = driver.send_and_receive(
            [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
        )
        state.active_node_ids = {
            msg.metadata.src_node_id for msg in msgs if not msg.has_error()
        }
        log(
            DEBUG,
            "[Stage 3] Received key shares from %s clients.",
            len(state.active_node_ids),
        )

        # Build collected shares dict
        collected_shares_dict: Dict[int, List[bytes]] = {}
        for nid in state.sampled_node_ids:
            collected_shares_dict[nid] = []
        for msg in msgs:
            res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
            nids = cast(List[int], res_dict[Key.NODE_ID_LIST])
            shares = cast(List[bytes], res_dict[Key.SHARE_LIST])
            for owner_nid, share in zip(nids, shares):
                collected_shares_dict[owner_nid].append(share)

        # Remove masks for every active client after collect_masked_vectors stage
        masked_vector = state.aggregate_ndarrays
        del state.aggregate_ndarrays
        for nid, share_list in collected_shares_dict.items():
            if len(share_list) < state.threshold:
                log(
                    ERROR, "Not enough shares to recover secret in unmask vectors stage"
                )
                return False
            secret = combine_shares(share_list)
            if nid in active_nids:
                # The seed for PRG is the private mask seed of an active client.
                private_mask = pseudo_rand_gen(
                    secret, state.mod_range, get_parameters_shape(masked_vector)
                )
                masked_vector = parameters_subtraction(masked_vector, private_mask)
            else:
                # The seed for PRG is the secret key 1 of a dropped client.
                neighbours = state.nid_to_neighbours[nid]
                neighbours.remove(nid)

                for neighbor_nid in neighbours:
                    shared_key = generate_shared_key(
                        bytes_to_private_key(secret),
                        bytes_to_public_key(state.nid_to_publickeys[neighbor_nid][0]),
                    )
                    pairwise_mask = pseudo_rand_gen(
                        shared_key, state.mod_range, get_parameters_shape(masked_vector)
                    )
                    if nid > neighbor_nid:
                        masked_vector = parameters_addition(
                            masked_vector, pairwise_mask
                        )
                    else:
                        masked_vector = parameters_subtraction(
                            masked_vector, pairwise_mask
                        )
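        # Within each pair of neighbours, the two nodes apply their shared pairwise
        # mask with opposite signs (ordered by node ID), so masks from surviving
        # pairs cancel in the sum; the loop above recomputes a dropped node's
        # pairwise masks from its recovered key and applies the opposite sign to
        # strip them out.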
        recon_parameters = parameters_mod(masked_vector, state.mod_range)
        q_total_ratio, recon_parameters = factor_extract(recon_parameters)
        inv_dq_total_ratio = state.quantization_range / q_total_ratio
        # recon_parameters = parameters_divide(recon_parameters, total_weights_factor)
        aggregated_vector = dequantize(
            recon_parameters,
            state.clipping_range,
            state.quantization_range,
        )
        offset = -(len(active_nids) - 1) * state.clipping_range
        for vec in aggregated_vector:
            vec += offset
            vec *= inv_dq_total_ratio
        state.aggregate_ndarrays = aggregated_vector
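        # Dequantization maps a single update back into the interval
        # [-clipping_range, clipping_range]; summing the updates of k active
        # clients accumulates k - 1 extra shifts of `clipping_range`, which the
        # `offset` above removes. `inv_dq_total_ratio` then rescales the weighted
        # sum "w * params" by the extracted total weight factor to recover the
        # weighted average.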

        # No exception/failure handling currently
        log(
            INFO,
            "aggregate_fit: received %s results and %s failures",
            1,
            0,
        )

        final_fitres = FitRes(
            status=Status(code=Code.OK, message=""),
            parameters=ndarrays_to_parameters(aggregated_vector),
            num_examples=round(state.max_weight / inv_dq_total_ratio),
            metrics={},
        )
        empty_proxy = DriverClientProxy(
            0,
            driver.grpc_driver,  # type: ignore
            False,
            driver.run_id,  # type: ignore
        )
        aggregated_result = context.strategy.aggregate_fit(
            current_round, [(empty_proxy, final_fitres)], []
        )
        parameters_aggregated, metrics_aggregated = aggregated_result

        # Update the parameters and write history
        if parameters_aggregated:
            paramsrecord = compat.parameters_to_parametersrecord(
                parameters_aggregated, True
            )
            context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
            context.history.add_metrics_distributed_fit(
                server_round=current_round, metrics=metrics_aggregated
            )
        return True
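How the new workflow is meant to be used: a minimal sketch of wiring SecAggPlusWorkflow into a ServerApp via the DefaultWorkflow also touched in this release (see flwr/server/workflow/default_workflows.py above). The entry point, strategy, and parameter values below are illustrative assumptions, not part of this diff.

from flwr.common import Context
from flwr.server import Driver, LegacyContext, ServerApp, ServerConfig
from flwr.server.strategy import FedAvg
from flwr.server.workflow import DefaultWorkflow, SecAggPlusWorkflow

app = ServerApp()

@app.main()
def main(driver: Driver, context: Context) -> None:
    # SecAggPlusWorkflow requires a LegacyContext (see its __call__ above)
    legacy_context = LegacyContext(
        state=context.state,
        config=ServerConfig(num_rounds=3),
        strategy=FedAvg(),
    )
    # Replace the plain fit round with the four SecAgg+ stages
    workflow = DefaultWorkflow(
        fit_workflow=SecAggPlusWorkflow(
            num_shares=3,                # split each private key into 3 shares
            reconstruction_threshold=2,  # any 2 shares recover a key
        )
    )
    workflow(driver, legacy_context)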