flwr-nightly 1.8.0.dev20240309__py3-none-any.whl → 1.8.0.dev20240311__py3-none-any.whl

Files changed (39)
  1. flwr/cli/flower_toml.py +4 -48
  2. flwr/cli/new/new.py +6 -3
  3. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -3
  4. flwr/cli/new/templates/app/pyproject.toml.tpl +1 -1
  5. flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +2 -2
  6. flwr/cli/utils.py +14 -1
  7. flwr/client/app.py +39 -5
  8. flwr/client/client_app.py +1 -47
  9. flwr/client/mod/__init__.py +2 -1
  10. flwr/client/mod/secure_aggregation/__init__.py +2 -0
  11. flwr/client/mod/secure_aggregation/secagg_mod.py +30 -0
  12. flwr/client/mod/secure_aggregation/secaggplus_mod.py +73 -57
  13. flwr/common/grpc.py +3 -3
  14. flwr/common/logger.py +78 -15
  15. flwr/common/object_ref.py +140 -0
  16. flwr/common/secure_aggregation/ndarrays_arithmetic.py +5 -5
  17. flwr/common/secure_aggregation/secaggplus_constants.py +7 -6
  18. flwr/common/secure_aggregation/secaggplus_utils.py +15 -15
  19. flwr/server/compat/app.py +2 -1
  20. flwr/server/driver/grpc_driver.py +4 -4
  21. flwr/server/history.py +22 -15
  22. flwr/server/run_serverapp.py +22 -4
  23. flwr/server/server.py +27 -23
  24. flwr/server/server_app.py +1 -47
  25. flwr/server/server_config.py +9 -0
  26. flwr/server/strategy/fedavg.py +2 -0
  27. flwr/server/superlink/fleet/vce/vce_api.py +9 -2
  28. flwr/server/superlink/state/in_memory_state.py +34 -32
  29. flwr/server/workflow/__init__.py +3 -0
  30. flwr/server/workflow/constant.py +32 -0
  31. flwr/server/workflow/default_workflows.py +52 -57
  32. flwr/server/workflow/secure_aggregation/__init__.py +24 -0
  33. flwr/server/workflow/secure_aggregation/secagg_workflow.py +112 -0
  34. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +676 -0
  35. {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/METADATA +1 -1
  36. {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/RECORD +39 -33
  37. {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/LICENSE +0 -0
  38. {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/WHEEL +0 -0
  39. {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/entry_points.txt +0 -0
flwr/server/workflow/secure_aggregation/secaggplus_workflow.py
@@ -0,0 +1,676 @@
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """Workflow for the SecAgg+ protocol."""
+
+
+ import random
+ from dataclasses import dataclass, field
+ from logging import DEBUG, ERROR, INFO, WARN
+ from typing import Dict, List, Optional, Set, Union, cast
+
+ import flwr.common.recordset_compat as compat
+ from flwr.common import (
+     Code,
+     ConfigsRecord,
+     Context,
+     FitRes,
+     Message,
+     MessageType,
+     NDArrays,
+     RecordSet,
+     Status,
+     bytes_to_ndarray,
+     log,
+     ndarrays_to_parameters,
+ )
+ from flwr.common.secure_aggregation.crypto.shamir import combine_shares
+ from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
+     bytes_to_private_key,
+     bytes_to_public_key,
+     generate_shared_key,
+ )
+ from flwr.common.secure_aggregation.ndarrays_arithmetic import (
+     factor_extract,
+     get_parameters_shape,
+     parameters_addition,
+     parameters_mod,
+     parameters_subtraction,
+ )
+ from flwr.common.secure_aggregation.quantization import dequantize
+ from flwr.common.secure_aggregation.secaggplus_constants import (
+     RECORD_KEY_CONFIGS,
+     Key,
+     Stage,
+ )
+ from flwr.common.secure_aggregation.secaggplus_utils import pseudo_rand_gen
+ from flwr.server.compat.driver_client_proxy import DriverClientProxy
+ from flwr.server.compat.legacy_context import LegacyContext
+ from flwr.server.driver import Driver
+
+ from ..constant import MAIN_CONFIGS_RECORD, MAIN_PARAMS_RECORD
+ from ..constant import Key as WorkflowKey
+
+
+ @dataclass
+ class WorkflowState:  # pylint: disable=R0902
+     """The state of the SecAgg+ protocol."""
+
+     nid_to_fitins: Dict[int, RecordSet] = field(default_factory=dict)
+     sampled_node_ids: Set[int] = field(default_factory=set)
+     active_node_ids: Set[int] = field(default_factory=set)
+     num_shares: int = 0
+     threshold: int = 0
+     clipping_range: float = 0.0
+     quantization_range: int = 0
+     mod_range: int = 0
+     max_weight: float = 0.0
+     nid_to_neighbours: Dict[int, Set[int]] = field(default_factory=dict)
+     nid_to_publickeys: Dict[int, List[bytes]] = field(default_factory=dict)
+     forward_srcs: Dict[int, List[int]] = field(default_factory=dict)
+     forward_ciphertexts: Dict[int, List[bytes]] = field(default_factory=dict)
+     aggregate_ndarrays: NDArrays = field(default_factory=list)
+
+
+ class SecAggPlusWorkflow:
+     """The workflow for the SecAgg+ protocol.
+
+     The SecAgg+ protocol ensures the secure summation of integer vectors owned by
+     multiple parties, without accessing any individual integer vector. This workflow
+     allows the server to compute the weighted average of model parameters across all
+     clients, ensuring individual contributions remain private. This is achieved by
+     clients sending both a weighting factor and a weighted version of the locally
+     updated parameters, both of which are masked for privacy. Specifically, each
+     client uploads "[w, w * params]" with masks, where the weighting factor 'w' is
+     the number of examples ('num_examples') and 'params' represents the model
+     parameters ('parameters') from the client's `FitRes`. The server then aggregates
+     these contributions to compute the weighted average of model parameters.
+
+     The protocol involves four main stages:
+     - 'setup': Send SecAgg+ configuration to clients and collect their public keys.
+     - 'share keys': Broadcast public keys among clients and collect encrypted
+       secret key shares.
+     - 'collect masked vectors': Forward encrypted secret key shares to target
+       clients and collect masked model parameters.
+     - 'unmask': Collect secret key shares to decrypt and aggregate the model
+       parameters.
+
+     Only the aggregated model parameters are exposed and passed to
+     `Strategy.aggregate_fit`, ensuring individual data privacy.
+
+     Parameters
+     ----------
+     num_shares : Union[int, float]
+         The number of shares into which each client's private key is split under
+         the SecAgg+ protocol. If specified as a float, it represents the proportion
+         of all selected clients, and the number of shares will be set dynamically at
+         run time. A private key can be reconstructed from these shares, allowing
+         for the secure aggregation of model updates. Each client sends one share to
+         each of its neighbours while retaining one.
+     reconstruction_threshold : Union[int, float]
+         The minimum number of shares required to reconstruct a client's private
+         key, or, if specified as a float, the proportion of the total number of
+         shares needed for reconstruction. This threshold ensures privacy by
+         allowing for the recovery of contributions from dropped clients during
+         aggregation, without compromising individual client data.
+     max_weight : Optional[float] (default: 1000.0)
+         The maximum value of the weight that can be assigned to any single client's
+         update during the weighted average calculation on the server side, e.g., in
+         the FedAvg algorithm.
+     clipping_range : float, optional (default: 8.0)
+         The range within which model parameters are clipped before quantization.
+         This parameter ensures each model parameter is bounded within
+         [-clipping_range, clipping_range], facilitating quantization.
+     quantization_range : int, optional (default: 4194304, this equals 2**22)
+         The size of the range into which floating-point model parameters are
+         quantized, mapping each parameter to an integer in
+         [0, quantization_range-1]. This facilitates cryptographic operations on
+         the model updates.
+     modulus_range : int, optional (default: 4294967296, this equals 2**32)
+         The range of values from which random mask entries are uniformly sampled
+         ([0, modulus_range-1]). `modulus_range` must be less than 4294967296.
+         Please use 2**n values for `modulus_range` to prevent overflow issues.
+     timeout : Optional[float] (default: None)
+         The timeout duration in seconds. If specified, the workflow will wait for
+         replies for this duration each time. If `None`, there is no time limit and
+         the workflow will wait until replies for all messages are received.
+
+     Notes
+     -----
+     - Generally, a higher `num_shares` means greater robustness to dropouts at the
+       cost of increased computation; a higher `reconstruction_threshold` means
+       better privacy guarantees but less tolerance to dropouts.
+     - An overly large `max_weight` may compromise the precision of the quantization.
+     - `modulus_range` must be 2**n and larger than `quantization_range`.
+     - When `num_shares` is a float, it is interpreted as the proportion of all
+       selected clients, and hence the number of shares will be determined at run
+       time. This allows for dynamic adjustment based on the total number of
+       participating clients.
+     - Similarly, when `reconstruction_threshold` is a float, it is interpreted as
+       the proportion of the number of shares needed for the reconstruction of a
+       private key. This feature enables flexibility in setting the security
+       threshold relative to the number of distributed shares.
+     - `num_shares`, `reconstruction_threshold`, and the quantization parameters
+       (`clipping_range`, `quantization_range`, `modulus_range`) play critical roles
+       in balancing privacy, robustness, and efficiency within the SecAgg+ protocol.
+     """
+
+     def __init__(  # pylint: disable=R0913
+         self,
+         num_shares: Union[int, float],
+         reconstruction_threshold: Union[int, float],
+         *,
+         max_weight: float = 1000.0,
+         clipping_range: float = 8.0,
+         quantization_range: int = 4194304,
+         modulus_range: int = 4294967296,
+         timeout: Optional[float] = None,
+     ) -> None:
+         self.num_shares = num_shares
+         self.reconstruction_threshold = reconstruction_threshold
+         self.max_weight = max_weight
+         self.clipping_range = clipping_range
+         self.quantization_range = quantization_range
+         self.modulus_range = modulus_range
+         self.timeout = timeout
+
+         self._check_init_params()
+
+     def __call__(self, driver: Driver, context: Context) -> None:
+         """Run the SecAgg+ protocol."""
+         if not isinstance(context, LegacyContext):
+             raise TypeError(
+                 f"Expected a LegacyContext, but got {type(context).__name__}."
+             )
+         state = WorkflowState()
+
+         steps = (
+             self.setup_stage,
+             self.share_keys_stage,
+             self.collect_masked_vectors_stage,
+             self.unmask_stage,
+         )
+         log(INFO, "Secure aggregation commencing.")
+         for step in steps:
+             if not step(driver, context, state):
+                 log(INFO, "Secure aggregation halted.")
+                 return
+         log(INFO, "Secure aggregation completed.")
+
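Because `__call__` takes `(driver, context)`, the class plugs directly into a `ServerApp` as the fit phase of a `DefaultWorkflow`. A minimal usage sketch, assuming the Flower 1.8 `flwr.server` API (names such as `LegacyContext` and `DefaultWorkflow` are re-exported there in this release; adjust imports to your version):

```python
from flwr.server import LegacyContext, ServerApp, ServerConfig
from flwr.server.strategy import FedAvg
from flwr.server.workflow import DefaultWorkflow, SecAggPlusWorkflow

app = ServerApp()


@app.main()
def main(driver, context):
    # The SecAgg+ workflow requires a LegacyContext (see __call__ above)
    context = LegacyContext(
        state=context.state,
        config=ServerConfig(num_rounds=3),
        strategy=FedAvg(),
    )
    # Run SecAgg+ during the fit phase of every round
    workflow = DefaultWorkflow(
        fit_workflow=SecAggPlusWorkflow(num_shares=3, reconstruction_threshold=2)
    )
    workflow(driver, context)
```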
+     def _check_init_params(self) -> None:  # pylint: disable=R0912
+         # Check `num_shares`
+         if not isinstance(self.num_shares, (int, float)):
+             raise TypeError("`num_shares` must be of type int or float.")
+         if isinstance(self.num_shares, int):
+             if self.num_shares == 1:
+                 self.num_shares = 1.0
+             elif self.num_shares <= 2:
+                 raise ValueError("`num_shares` as an integer must be greater than 2.")
+             elif self.num_shares > self.modulus_range / self.quantization_range:
+                 log(
+                     WARN,
+                     "A `num_shares` larger than `modulus_range / quantization_range` "
+                     "will potentially cause overflow when computing the aggregated "
+                     "model parameters.",
+                 )
+         elif self.num_shares <= 0:
+             raise ValueError("`num_shares` as a float must be greater than 0.")
+
+         # Check `reconstruction_threshold`
+         if not isinstance(self.reconstruction_threshold, (int, float)):
+             raise TypeError("`reconstruction_threshold` must be of type int or float.")
+         if isinstance(self.reconstruction_threshold, int):
+             if self.reconstruction_threshold == 1:
+                 self.reconstruction_threshold = 1.0
+             elif isinstance(self.num_shares, int):
+                 if self.reconstruction_threshold >= self.num_shares:
+                     raise ValueError(
+                         "`reconstruction_threshold` must be less than `num_shares`."
+                     )
+         else:
+             if not 0 < self.reconstruction_threshold <= 1:
+                 raise ValueError(
+                     "If `reconstruction_threshold` is a float, "
+                     "it must be greater than 0 and less than or equal to 1."
+                 )
+
+         # Check `max_weight`
+         if self.max_weight <= 0:
+             raise ValueError("`max_weight` must be greater than 0.")
+
+         # Check `quantization_range`
+         if not isinstance(self.quantization_range, int) or self.quantization_range <= 0:
+             raise ValueError(
+                 "`quantization_range` must be an integer and greater than 0."
+             )
+
+         # Check `modulus_range`
+         if (
+             not isinstance(self.modulus_range, int)
+             or self.modulus_range <= self.quantization_range
+         ):
+             raise ValueError(
+                 "`modulus_range` must be an integer and "
+                 "greater than `quantization_range`."
+             )
+         if bin(self.modulus_range).count("1") != 1:
+             raise ValueError("`modulus_range` must be a power of 2.")
+
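The final check rejects any `modulus_range` whose binary representation has more than one set bit. A tiny standalone sketch of the same test (a hypothetical helper, not part of this diff):

```python
def is_power_of_two(n: int) -> bool:
    # A power of two has exactly one set bit; n & (n - 1) clears the lowest one
    return n > 0 and n & (n - 1) == 0

assert is_power_of_two(4294967296)      # 2**32, the default modulus_range
assert not is_power_of_two(3 * 2**20)   # would raise ValueError above
```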
+     def _check_threshold(self, state: WorkflowState) -> bool:
+         for node_id in state.sampled_node_ids:
+             active_neighbors = state.nid_to_neighbours[node_id] & state.active_node_ids
+             if len(active_neighbors) < state.threshold:
+                 log(ERROR, "Insufficient available nodes.")
+                 return False
+         return True
+
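`_check_threshold` aborts the round unless every sampled node still has at least `threshold` active neighbours; with fewer survivors, that node's key shares would be unrecoverable later. A toy restatement with made-up IDs:

```python
# Toy version of the check above (hypothetical values)
sampled = {1, 2, 3, 4}
active = {1, 2, 4}                                # node 3 dropped out
neighbours = {n: sampled - {n} for n in sampled}  # fully connected for brevity
threshold = 2
ok = all(len(neighbours[n] & active) >= threshold for n in sampled)
print(ok)  # True: every node keeps at least 2 active neighbours
```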
+     def setup_stage(  # pylint: disable=R0912, R0914, R0915
+         self, driver: Driver, context: LegacyContext, state: WorkflowState
+     ) -> bool:
+         """Execute the 'setup' stage."""
+         # Obtain fit instructions
+         cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
+         current_round = cast(int, cfg[WorkflowKey.CURRENT_ROUND])
+         parameters = compat.parametersrecord_to_parameters(
+             context.state.parameters_records[MAIN_PARAMS_RECORD],
+             keep_input=True,
+         )
+         proxy_fitins_lst = context.strategy.configure_fit(
+             current_round, parameters, context.client_manager
+         )
+         if not proxy_fitins_lst:
+             log(INFO, "configure_fit: no clients selected, cancel")
+             return False
+         log(
+             INFO,
+             "configure_fit: strategy sampled %s clients (out of %s)",
+             len(proxy_fitins_lst),
+             context.client_manager.num_available(),
+         )
+
+         state.nid_to_fitins = {
+             proxy.node_id: compat.fitins_to_recordset(fitins, False)
+             for proxy, fitins in proxy_fitins_lst
+         }
+
+         # Protocol config
+         sampled_node_ids = list(state.nid_to_fitins.keys())
+         num_samples = len(sampled_node_ids)
+         if num_samples < 2:
+             log(ERROR, "The number of samples should be greater than 1.")
+             return False
+         if isinstance(self.num_shares, float):
+             state.num_shares = round(self.num_shares * num_samples)
+             # If even
+             if state.num_shares < num_samples and state.num_shares & 1 == 0:
+                 state.num_shares += 1
+             # If too small
+             if state.num_shares <= 2:
+                 state.num_shares = num_samples
+         else:
+             state.num_shares = self.num_shares
+         if isinstance(self.reconstruction_threshold, float):
+             state.threshold = round(self.reconstruction_threshold * state.num_shares)
+             # Avoid too small threshold
+             state.threshold = max(state.threshold, 2)
+         else:
+             state.threshold = self.reconstruction_threshold
+         state.active_node_ids = set(sampled_node_ids)
+         state.clipping_range = self.clipping_range
+         state.quantization_range = self.quantization_range
+         state.mod_range = self.modulus_range
+         state.max_weight = self.max_weight
+         sa_params_dict = {
+             Key.STAGE: Stage.SETUP,
+             Key.SAMPLE_NUMBER: num_samples,
+             Key.SHARE_NUMBER: state.num_shares,
+             Key.THRESHOLD: state.threshold,
+             Key.CLIPPING_RANGE: state.clipping_range,
+             Key.TARGET_RANGE: state.quantization_range,
+             Key.MOD_RANGE: state.mod_range,
+             Key.MAX_WEIGHT: state.max_weight,
+         }
+
+         # The number of shares should preferably be odd in the SecAgg+ protocol.
+         if num_samples != state.num_shares and state.num_shares & 1 == 0:
+             log(WARN, "Number of shares in the SecAgg+ protocol should be odd.")
+             state.num_shares += 1
+
+         # Shuffle node IDs
+         random.shuffle(sampled_node_ids)
+         # Build neighbour relations (node ID -> secure IDs of neighbours)
+         half_share = state.num_shares >> 1
+         state.nid_to_neighbours = {
+             nid: {
+                 sampled_node_ids[(idx + offset) % num_samples]
+                 for offset in range(-half_share, half_share + 1)
+             }
+             for idx, nid in enumerate(sampled_node_ids)
+         }
+
+         state.sampled_node_ids = state.active_node_ids
+
+         # Send setup configuration to clients
+         cfgs_record = ConfigsRecord(sa_params_dict)  # type: ignore
+         content = RecordSet(configs_records={RECORD_KEY_CONFIGS: cfgs_record})
+
+         def make(nid: int) -> Message:
+             return driver.create_message(
+                 content=content,
+                 message_type=MessageType.TRAIN,
+                 dst_node_id=nid,
+                 group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
+                 ttl="",
+             )
+
+         log(
+             DEBUG,
+             "[Stage 0] Sending configurations to %s clients.",
+             len(state.active_node_ids),
+         )
+         msgs = driver.send_and_receive(
+             [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
+         )
+         state.active_node_ids = {
+             msg.metadata.src_node_id for msg in msgs if not msg.has_error()
+         }
+         log(
+             DEBUG,
+             "[Stage 0] Received public keys from %s clients.",
+             len(state.active_node_ids),
+         )
+
+         for msg in msgs:
+             if msg.has_error():
+                 continue
+             key_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
+             node_id = msg.metadata.src_node_id
+             pk1, pk2 = key_dict[Key.PUBLIC_KEY_1], key_dict[Key.PUBLIC_KEY_2]
+             state.nid_to_publickeys[node_id] = [cast(bytes, pk1), cast(bytes, pk2)]
+
+         return self._check_threshold(state)
+
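The dictionary comprehension in `setup_stage` arranges the shuffled node IDs on a ring and assigns each node the `num_shares // 2` neighbours on either side of it, plus itself. A standalone sketch with toy IDs:

```python
ids = [10, 20, 30, 40, 50]   # shuffled node IDs (toy values)
num_shares = 3
half = num_shares >> 1       # one neighbour on each side
neigh = {
    nid: {ids[(idx + off) % len(ids)] for off in range(-half, half + 1)}
    for idx, nid in enumerate(ids)
}
print(neigh[30])  # {20, 30, 40} in some order: itself plus its ring neighbours
```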
+     def share_keys_stage(  # pylint: disable=R0914
+         self, driver: Driver, context: LegacyContext, state: WorkflowState
+     ) -> bool:
+         """Execute the 'share keys' stage."""
+         cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
+
+         def make(nid: int) -> Message:
+             neighbours = state.nid_to_neighbours[nid] & state.active_node_ids
+             cfgs_record = ConfigsRecord(
+                 {str(nid): state.nid_to_publickeys[nid] for nid in neighbours}
+             )
+             cfgs_record[Key.STAGE] = Stage.SHARE_KEYS
+             content = RecordSet(configs_records={RECORD_KEY_CONFIGS: cfgs_record})
+             return driver.create_message(
+                 content=content,
+                 message_type=MessageType.TRAIN,
+                 dst_node_id=nid,
+                 group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
+                 ttl="",
+             )
+
+         # Broadcast public keys to clients and receive secret key shares
+         log(
+             DEBUG,
+             "[Stage 1] Forwarding public keys to %s clients.",
+             len(state.active_node_ids),
+         )
+         msgs = driver.send_and_receive(
+             [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
+         )
+         state.active_node_ids = {
+             msg.metadata.src_node_id for msg in msgs if not msg.has_error()
+         }
+         log(
+             DEBUG,
+             "[Stage 1] Received encrypted key shares from %s clients.",
+             len(state.active_node_ids),
+         )
+
+         # Build forward packet list dictionary
+         srcs: List[int] = []
+         dsts: List[int] = []
+         ciphertexts: List[bytes] = []
+         fwd_ciphertexts: Dict[int, List[bytes]] = {
+             nid: [] for nid in state.active_node_ids
+         }  # dest node ID -> list of ciphertexts
+         fwd_srcs: Dict[int, List[int]] = {
+             nid: [] for nid in state.active_node_ids
+         }  # dest node ID -> list of src node IDs
+         for msg in msgs:
+             node_id = msg.metadata.src_node_id
+             res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
+             dst_lst = cast(List[int], res_dict[Key.DESTINATION_LIST])
+             ctxt_lst = cast(List[bytes], res_dict[Key.CIPHERTEXT_LIST])
+             srcs += [node_id] * len(dst_lst)
+             dsts += dst_lst
+             ciphertexts += ctxt_lst
+
+         for src, dst, ciphertext in zip(srcs, dsts, ciphertexts):
+             if dst in fwd_ciphertexts:
+                 fwd_ciphertexts[dst].append(ciphertext)
+                 fwd_srcs[dst].append(src)
+
+         state.forward_srcs = fwd_srcs
+         state.forward_ciphertexts = fwd_ciphertexts
+
+         return self._check_threshold(state)
+
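The stage ends by transposing per-sender reply lists into per-recipient forwarding tables: each client reports parallel destination/ciphertext lists, and the server regroups them by destination. A toy version of the same regrouping (made-up node IDs and ciphertext bytes):

```python
from collections import defaultdict

# src node -> (destination list, ciphertext list), as reported by clients
replies = {
    1: ([2, 3], [b"c12", b"c13"]),
    2: ([1, 3], [b"c21", b"c23"]),
}
fwd_srcs, fwd_cts = defaultdict(list), defaultdict(list)
for src, (dsts, cts) in replies.items():
    for dst, ct in zip(dsts, cts):
        fwd_srcs[dst].append(src)   # who sent a share to dst
        fwd_cts[dst].append(ct)     # the encrypted share itself
print(fwd_srcs[3], fwd_cts[3])      # [1, 2] [b'c13', b'c23']
```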
+     def collect_masked_vectors_stage(
+         self, driver: Driver, context: LegacyContext, state: WorkflowState
+     ) -> bool:
+         """Execute the 'collect masked vectors' stage."""
+         cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
+
+         # Send secret key shares to clients (plus FitIns) and collect masked vectors
+         def make(nid: int) -> Message:
+             cfgs_dict = {
+                 Key.STAGE: Stage.COLLECT_MASKED_VECTORS,
+                 Key.CIPHERTEXT_LIST: state.forward_ciphertexts[nid],
+                 Key.SOURCE_LIST: state.forward_srcs[nid],
+             }
+             cfgs_record = ConfigsRecord(cfgs_dict)  # type: ignore
+             content = state.nid_to_fitins[nid]
+             content.configs_records[RECORD_KEY_CONFIGS] = cfgs_record
+             return driver.create_message(
+                 content=content,
+                 message_type=MessageType.TRAIN,
+                 dst_node_id=nid,
+                 group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
+                 ttl="",
+             )
+
+         log(
+             DEBUG,
+             "[Stage 2] Forwarding encrypted key shares to %s clients.",
+             len(state.active_node_ids),
+         )
+         msgs = driver.send_and_receive(
+             [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
+         )
+         state.active_node_ids = {
+             msg.metadata.src_node_id for msg in msgs if not msg.has_error()
+         }
+         log(
+             DEBUG,
+             "[Stage 2] Received masked vectors from %s clients.",
+             len(state.active_node_ids),
+         )
+
+         # Clear cache
+         del state.forward_ciphertexts, state.forward_srcs, state.nid_to_fitins
+
+         # Sum collected masked vectors and compute active/dead node IDs
+         masked_vector = None
+         for msg in msgs:
+             res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
+             bytes_list = cast(List[bytes], res_dict[Key.MASKED_PARAMETERS])
+             client_masked_vec = [bytes_to_ndarray(b) for b in bytes_list]
+             if masked_vector is None:
+                 masked_vector = client_masked_vec
+             else:
+                 masked_vector = parameters_addition(masked_vector, client_masked_vec)
+         if masked_vector is not None:
+             masked_vector = parameters_mod(masked_vector, state.mod_range)
+             state.aggregate_ndarrays = masked_vector
+
+         return self._check_threshold(state)
+
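The elementwise sum modulo `mod_range` is the heart of the protocol: a pairwise mask is added by one client and subtracted by its peer, so the masks cancel in the aggregate while each individual vector stays hidden. A two-client toy example in plain NumPy, standing in for `parameters_addition`/`parameters_mod`:

```python
import numpy as np

mod_range = 2**32
updates = [np.array([1, 2, 3], dtype=np.int64), np.array([4, 5, 6], dtype=np.int64)]
pairwise_mask = np.array([7, 7, 7], dtype=np.int64)  # from a shared key in SecAgg+

masked_a = (updates[0] + pairwise_mask) % mod_range  # client A adds the mask
masked_b = (updates[1] - pairwise_mask) % mod_range  # client B subtracts it
print((masked_a + masked_b) % mod_range)             # [5 7 9] -- the masks cancel
```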
+     def unmask_stage(  # pylint: disable=R0912, R0914, R0915
+         self, driver: Driver, context: LegacyContext, state: WorkflowState
+     ) -> bool:
+         """Execute the 'unmask' stage."""
+         cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
+         current_round = cast(int, cfg[WorkflowKey.CURRENT_ROUND])
+
+         # Construct active node IDs and dead node IDs
+         active_nids = state.active_node_ids
+         dead_nids = state.sampled_node_ids - active_nids
+
+         # Send secure IDs of active and dead clients and collect key shares
+         def make(nid: int) -> Message:
+             neighbours = state.nid_to_neighbours[nid]
+             cfgs_dict = {
+                 Key.STAGE: Stage.UNMASK,
+                 Key.ACTIVE_NODE_ID_LIST: list(neighbours & active_nids),
+                 Key.DEAD_NODE_ID_LIST: list(neighbours & dead_nids),
+             }
+             cfgs_record = ConfigsRecord(cfgs_dict)  # type: ignore
+             content = RecordSet(configs_records={RECORD_KEY_CONFIGS: cfgs_record})
+             return driver.create_message(
+                 content=content,
+                 message_type=MessageType.TRAIN,
+                 dst_node_id=nid,
+                 group_id=str(current_round),
+                 ttl="",
+             )
+
+         log(
+             DEBUG,
+             "[Stage 3] Requesting key shares from %s clients to remove masks.",
+             len(state.active_node_ids),
+         )
+         msgs = driver.send_and_receive(
+             [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
+         )
+         state.active_node_ids = {
+             msg.metadata.src_node_id for msg in msgs if not msg.has_error()
+         }
+         log(
+             DEBUG,
+             "[Stage 3] Received key shares from %s clients.",
+             len(state.active_node_ids),
+         )
+
+         # Build collected shares dict
+         collected_shares_dict: Dict[int, List[bytes]] = {}
+         for nid in state.sampled_node_ids:
+             collected_shares_dict[nid] = []
+         for msg in msgs:
+             res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
+             nids = cast(List[int], res_dict[Key.NODE_ID_LIST])
+             shares = cast(List[bytes], res_dict[Key.SHARE_LIST])
+             for owner_nid, share in zip(nids, shares):
+                 collected_shares_dict[owner_nid].append(share)
+
+         # Remove masks for every active client after collect_masked_vectors stage
+         masked_vector = state.aggregate_ndarrays
+         del state.aggregate_ndarrays
+         for nid, share_list in collected_shares_dict.items():
+             if len(share_list) < state.threshold:
+                 log(
+                     ERROR, "Not enough shares to recover secret in unmask vectors stage"
+                 )
+                 return False
+             secret = combine_shares(share_list)
+             if nid in active_nids:
+                 # The seed for PRG is the private mask seed of an active client.
+                 private_mask = pseudo_rand_gen(
+                     secret, state.mod_range, get_parameters_shape(masked_vector)
+                 )
+                 masked_vector = parameters_subtraction(masked_vector, private_mask)
+             else:
+                 # The seed for PRG is the secret key 1 of a dropped client.
+                 neighbours = state.nid_to_neighbours[nid]
+                 neighbours.remove(nid)
+
+                 for neighbor_nid in neighbours:
+                     shared_key = generate_shared_key(
+                         bytes_to_private_key(secret),
+                         bytes_to_public_key(state.nid_to_publickeys[neighbor_nid][0]),
+                     )
+                     pairwise_mask = pseudo_rand_gen(
+                         shared_key, state.mod_range, get_parameters_shape(masked_vector)
+                     )
+                     if nid > neighbor_nid:
+                         masked_vector = parameters_addition(
+                             masked_vector, pairwise_mask
+                         )
+                     else:
+                         masked_vector = parameters_subtraction(
+                             masked_vector, pairwise_mask
+                         )
+         recon_parameters = parameters_mod(masked_vector, state.mod_range)
+         q_total_ratio, recon_parameters = factor_extract(recon_parameters)
+         inv_dq_total_ratio = state.quantization_range / q_total_ratio
+         # recon_parameters = parameters_divide(recon_parameters, total_weights_factor)
+         aggregated_vector = dequantize(
+             recon_parameters,
+             state.clipping_range,
+             state.quantization_range,
+         )
+         offset = -(len(active_nids) - 1) * state.clipping_range
+         for vec in aggregated_vector:
+             vec += offset
+             vec *= inv_dq_total_ratio
+         state.aggregate_ndarrays = aggregated_vector
+
+         # No exception/failure handling currently
+         log(
+             INFO,
+             "aggregate_fit: received %s results and %s failures",
+             1,
+             0,
+         )
+
+         final_fitres = FitRes(
+             status=Status(code=Code.OK, message=""),
+             parameters=ndarrays_to_parameters(aggregated_vector),
+             num_examples=round(state.max_weight / inv_dq_total_ratio),
+             metrics={},
+         )
+         empty_proxy = DriverClientProxy(
+             0,
+             driver.grpc_driver,  # type: ignore
+             False,
+             driver.run_id,  # type: ignore
+         )
+         aggregated_result = context.strategy.aggregate_fit(
+             current_round, [(empty_proxy, final_fitres)], []
+         )
+         parameters_aggregated, metrics_aggregated = aggregated_result
+
+         # Update the parameters and write history
+         if parameters_aggregated:
+             paramsrecord = compat.parameters_to_parametersrecord(
+                 parameters_aggregated, True
+             )
+             context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
+             context.history.add_metrics_distributed_fit(
+                 server_round=current_round, metrics=metrics_aggregated
+             )
+         return True
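After unmasking, `dequantize` maps the integer aggregate back to floats, and the final loop corrects for the per-client offset introduced by clipping. A self-contained sketch of the clip/quantize round trip, using hypothetical helpers that mirror `flwr.common.secure_aggregation.quantization` (the real functions may differ in rounding details):

```python
import numpy as np

clip, target = 8.0, 2**22   # clipping_range, quantization_range

def quantize(x: np.ndarray) -> np.ndarray:
    x = np.clip(x, -clip, clip)                       # bound to [-clip, clip]
    return np.round((x + clip) / (2 * clip) * (target - 1)).astype(np.int64)

def dequantize(q: np.ndarray) -> np.ndarray:
    return q.astype(np.float64) / (target - 1) * (2 * clip) - clip

x = np.array([-1.5, 0.0, 3.25])
print(dequantize(quantize(x)))  # ~[-1.5, 0.0, 3.25], up to quantization error
```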
{flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: flwr-nightly
- Version: 1.8.0.dev20240309
+ Version: 1.8.0.dev20240311
  Summary: Flower: A Friendly Federated Learning Framework
  Home-page: https://flower.ai
  License: Apache-2.0