flwr-nightly 1.8.0.dev20240309__py3-none-any.whl → 1.8.0.dev20240311__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. flwr/cli/flower_toml.py +4 -48
  2. flwr/cli/new/new.py +6 -3
  3. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -3
  4. flwr/cli/new/templates/app/pyproject.toml.tpl +1 -1
  5. flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +2 -2
  6. flwr/cli/utils.py +14 -1
  7. flwr/client/app.py +39 -5
  8. flwr/client/client_app.py +1 -47
  9. flwr/client/mod/__init__.py +2 -1
  10. flwr/client/mod/secure_aggregation/__init__.py +2 -0
  11. flwr/client/mod/secure_aggregation/secagg_mod.py +30 -0
  12. flwr/client/mod/secure_aggregation/secaggplus_mod.py +73 -57
  13. flwr/common/grpc.py +3 -3
  14. flwr/common/logger.py +78 -15
  15. flwr/common/object_ref.py +140 -0
  16. flwr/common/secure_aggregation/ndarrays_arithmetic.py +5 -5
  17. flwr/common/secure_aggregation/secaggplus_constants.py +7 -6
  18. flwr/common/secure_aggregation/secaggplus_utils.py +15 -15
  19. flwr/server/compat/app.py +2 -1
  20. flwr/server/driver/grpc_driver.py +4 -4
  21. flwr/server/history.py +22 -15
  22. flwr/server/run_serverapp.py +22 -4
  23. flwr/server/server.py +27 -23
  24. flwr/server/server_app.py +1 -47
  25. flwr/server/server_config.py +9 -0
  26. flwr/server/strategy/fedavg.py +2 -0
  27. flwr/server/superlink/fleet/vce/vce_api.py +9 -2
  28. flwr/server/superlink/state/in_memory_state.py +34 -32
  29. flwr/server/workflow/__init__.py +3 -0
  30. flwr/server/workflow/constant.py +32 -0
  31. flwr/server/workflow/default_workflows.py +52 -57
  32. flwr/server/workflow/secure_aggregation/__init__.py +24 -0
  33. flwr/server/workflow/secure_aggregation/secagg_workflow.py +112 -0
  34. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +676 -0
  35. {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/METADATA +1 -1
  36. {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/RECORD +39 -33
  37. {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/LICENSE +0 -0
  38. {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/WHEEL +0 -0
  39. {flwr_nightly-1.8.0.dev20240309.dist-info → flwr_nightly-1.8.0.dev20240311.dist-info}/entry_points.txt +0 -0
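The headline addition is `flwr/server/workflow/secure_aggregation/secaggplus_workflow.py` (item 34), a server-side workflow implementing the SecAgg+ secure-aggregation protocol; its full content follows below. As a rough usage sketch — assuming `DefaultWorkflow` and `SecAggPlusWorkflow` are exported from `flwr.server.workflow` via the new `__init__.py` files (items 29 and 32):

    from flwr.server.workflow import DefaultWorkflow, SecAggPlusWorkflow

    # Plug the SecAgg+ fit workflow into the default server workflow.
    workflow = DefaultWorkflow(
        fit_workflow=SecAggPlusWorkflow(
            num_shares=3,                # split each client's secrets into 3 shares
            reconstruction_threshold=2,  # any 2 shares suffice to reconstruct
        )
    )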
@@ -0,0 +1,676 @@
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """Workflow for the SecAgg+ protocol."""
+
+
+ import random
+ from dataclasses import dataclass, field
+ from logging import DEBUG, ERROR, INFO, WARN
+ from typing import Dict, List, Optional, Set, Union, cast
+
+ import flwr.common.recordset_compat as compat
+ from flwr.common import (
+     Code,
+     ConfigsRecord,
+     Context,
+     FitRes,
+     Message,
+     MessageType,
+     NDArrays,
+     RecordSet,
+     Status,
+     bytes_to_ndarray,
+     log,
+     ndarrays_to_parameters,
+ )
+ from flwr.common.secure_aggregation.crypto.shamir import combine_shares
+ from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
+     bytes_to_private_key,
+     bytes_to_public_key,
+     generate_shared_key,
+ )
+ from flwr.common.secure_aggregation.ndarrays_arithmetic import (
+     factor_extract,
+     get_parameters_shape,
+     parameters_addition,
+     parameters_mod,
+     parameters_subtraction,
+ )
+ from flwr.common.secure_aggregation.quantization import dequantize
+ from flwr.common.secure_aggregation.secaggplus_constants import (
+     RECORD_KEY_CONFIGS,
+     Key,
+     Stage,
+ )
+ from flwr.common.secure_aggregation.secaggplus_utils import pseudo_rand_gen
+ from flwr.server.compat.driver_client_proxy import DriverClientProxy
+ from flwr.server.compat.legacy_context import LegacyContext
+ from flwr.server.driver import Driver
+
+ from ..constant import MAIN_CONFIGS_RECORD, MAIN_PARAMS_RECORD
+ from ..constant import Key as WorkflowKey
+
+
+ @dataclass
+ class WorkflowState:  # pylint: disable=R0902
+     """The state of the SecAgg+ protocol."""
+
+     nid_to_fitins: Dict[int, RecordSet] = field(default_factory=dict)
+     sampled_node_ids: Set[int] = field(default_factory=set)
+     active_node_ids: Set[int] = field(default_factory=set)
+     num_shares: int = 0
+     threshold: int = 0
+     clipping_range: float = 0.0
+     quantization_range: int = 0
+     mod_range: int = 0
+     max_weight: float = 0.0
+     nid_to_neighbours: Dict[int, Set[int]] = field(default_factory=dict)
+     nid_to_publickeys: Dict[int, List[bytes]] = field(default_factory=dict)
+     forward_srcs: Dict[int, List[int]] = field(default_factory=dict)
+     forward_ciphertexts: Dict[int, List[bytes]] = field(default_factory=dict)
+     aggregate_ndarrays: NDArrays = field(default_factory=list)
+
+
+ class SecAggPlusWorkflow:
+     """The workflow for the SecAgg+ protocol.
+
+     The SecAgg+ protocol ensures the secure summation of integer vectors owned by
+     multiple parties, without accessing any individual integer vector. This workflow
+     allows the server to compute the weighted average of model parameters across all
+     clients, ensuring individual contributions remain private. This is achieved by
+     clients sending both a weighting factor and a weighted version of the locally
+     updated parameters, with both masked for privacy. Specifically, each client
+     uploads "[w, w * params]" with masks, where the weighting factor 'w' is the
+     number of examples ('num_examples') and 'params' represents the model parameters
+     ('parameters') from the client's `FitRes`. The server then aggregates these
+     contributions to compute the weighted average of model parameters.
+
+     The protocol involves four main stages:
+     - 'setup': Send SecAgg+ configuration to clients and collect their public keys.
+     - 'share keys': Broadcast public keys among clients and collect encrypted secret
+       key shares.
+     - 'collect masked vectors': Forward encrypted secret key shares to target clients
+       and collect masked model parameters.
+     - 'unmask': Collect secret key shares to decrypt and aggregate the model
+       parameters.
+
+     Only the aggregated model parameters are exposed and passed to
+     `Strategy.aggregate_fit`, ensuring individual data privacy.
+
+     Parameters
+     ----------
+     num_shares : Union[int, float]
+         The number of shares into which each client's private key is split under
+         the SecAgg+ protocol. If specified as a float, it represents the proportion
+         of all selected clients, and the number of shares will be set dynamically at
+         run time. A private key can be reconstructed from these shares, allowing
+         for the secure aggregation of model updates. Each client sends one share to
+         each of its neighbours while retaining one.
+     reconstruction_threshold : Union[int, float]
+         The minimum number of shares required to reconstruct a client's private key,
+         or, if specified as a float, the proportion of the total number of shares
+         needed for reconstruction. This threshold ensures privacy by allowing
+         for the recovery of contributions from dropped clients during aggregation,
+         without compromising individual client data.
+     max_weight : float, optional (default: 1000.0)
+         The maximum value of the weight that can be assigned to any single client's
+         update during the weighted average calculation on the server side, e.g., in
+         the FedAvg algorithm.
+     clipping_range : float, optional (default: 8.0)
+         The range within which model parameters are clipped before quantization.
+         This parameter ensures each model parameter is bounded within
+         [-clipping_range, clipping_range], facilitating quantization.
+     quantization_range : int, optional (default: 4194304, which equals 2**22)
+         The size of the range into which floating-point model parameters are
+         quantized, mapping each parameter to an integer in [0, quantization_range-1].
+         This facilitates cryptographic operations on the model updates.
+     modulus_range : int, optional (default: 4294967296, which equals 2**32)
+         The range of values from which random mask entries are uniformly sampled
+         ([0, modulus_range-1]). `modulus_range` must not exceed 4294967296.
+         Please use 2**n values for `modulus_range` to prevent overflow issues.
+     timeout : Optional[float] (default: None)
+         The timeout duration in seconds. If specified, the workflow will wait for
+         replies for this duration each time. If `None`, there is no time limit and
+         the workflow will wait until replies for all messages are received.
+
+     Notes
+     -----
+     - Generally, a higher `num_shares` makes the protocol more robust to dropouts at
+       the cost of increased computation; a higher `reconstruction_threshold` gives
+       better privacy guarantees but less tolerance to dropouts.
+     - A `max_weight` that is too large may compromise the precision of the
+       quantization.
+     - `modulus_range` must be 2**n and larger than `quantization_range`.
+     - When `num_shares` is a float, it is interpreted as the proportion of all
+       selected clients, and hence the number of shares will be determined at run
+       time. This allows for dynamic adjustment based on the total number of
+       participating clients.
+     - Similarly, when `reconstruction_threshold` is a float, it is interpreted as
+       the proportion of the number of shares needed for the reconstruction of a
+       private key. This feature enables flexibility in setting the security
+       threshold relative to the number of distributed shares.
+     - `num_shares`, `reconstruction_threshold`, and the quantization parameters
+       (`clipping_range`, `quantization_range`, `modulus_range`) play critical roles
+       in balancing privacy, robustness, and efficiency within the SecAgg+ protocol.
+     """
+
+     def __init__(  # pylint: disable=R0913
+         self,
+         num_shares: Union[int, float],
+         reconstruction_threshold: Union[int, float],
+         *,
+         max_weight: float = 1000.0,
+         clipping_range: float = 8.0,
+         quantization_range: int = 4194304,
+         modulus_range: int = 4294967296,
+         timeout: Optional[float] = None,
+     ) -> None:
+         self.num_shares = num_shares
+         self.reconstruction_threshold = reconstruction_threshold
+         self.max_weight = max_weight
+         self.clipping_range = clipping_range
+         self.quantization_range = quantization_range
+         self.modulus_range = modulus_range
+         self.timeout = timeout
+
+         self._check_init_params()
+
+     def __call__(self, driver: Driver, context: Context) -> None:
+         """Run the SecAgg+ protocol."""
+         if not isinstance(context, LegacyContext):
+             raise TypeError(
+                 f"Expected a LegacyContext, but got {type(context).__name__}."
+             )
+         state = WorkflowState()
+
+         steps = (
+             self.setup_stage,
+             self.share_keys_stage,
+             self.collect_masked_vectors_stage,
+             self.unmask_stage,
+         )
+         log(INFO, "Secure aggregation commencing.")
+         for step in steps:
+             if not step(driver, context, state):
+                 log(INFO, "Secure aggregation halted.")
+                 return
+         log(INFO, "Secure aggregation completed.")
+
+     def _check_init_params(self) -> None:  # pylint: disable=R0912
+         # Check `num_shares`
+         if not isinstance(self.num_shares, (int, float)):
+             raise TypeError("`num_shares` must be of type int or float.")
+         if isinstance(self.num_shares, int):
+             if self.num_shares == 1:
+                 self.num_shares = 1.0
+             elif self.num_shares <= 2:
+                 raise ValueError("`num_shares` as an integer must be greater than 2.")
+             elif self.num_shares > self.modulus_range / self.quantization_range:
+                 log(
+                     WARN,
+                     "A `num_shares` larger than `modulus_range / quantization_range` "
+                     "will potentially cause overflow when computing the aggregated "
+                     "model parameters.",
+                 )
+         elif self.num_shares <= 0:
+             raise ValueError("`num_shares` as a float must be greater than 0.")
+
+         # Check `reconstruction_threshold`
+         if not isinstance(self.reconstruction_threshold, (int, float)):
+             raise TypeError("`reconstruction_threshold` must be of type int or float.")
+         if isinstance(self.reconstruction_threshold, int):
+             if self.reconstruction_threshold == 1:
+                 self.reconstruction_threshold = 1.0
+             elif isinstance(self.num_shares, int):
+                 if self.reconstruction_threshold >= self.num_shares:
+                     raise ValueError(
+                         "`reconstruction_threshold` must be less than `num_shares`."
+                     )
+         else:
+             if not 0 < self.reconstruction_threshold <= 1:
+                 raise ValueError(
+                     "If `reconstruction_threshold` is a float, "
+                     "it must be greater than 0 and less than or equal to 1."
+                 )
+
+         # Check `max_weight`
+         if self.max_weight <= 0:
+             raise ValueError("`max_weight` must be greater than 0.")
+
+         # Check `quantization_range`
+         if not isinstance(self.quantization_range, int) or self.quantization_range <= 0:
+             raise ValueError(
+                 "`quantization_range` must be an integer and greater than 0."
+             )
+
+         # Check `modulus_range`
+         if (
+             not isinstance(self.modulus_range, int)
+             or self.modulus_range <= self.quantization_range
+         ):
+             raise ValueError(
+                 "`modulus_range` must be an integer and "
+                 "greater than `quantization_range`."
+             )
+         if bin(self.modulus_range).count("1") != 1:
+             raise ValueError("`modulus_range` must be a power of 2.")
+
+     def _check_threshold(self, state: WorkflowState) -> bool:
+         for node_id in state.sampled_node_ids:
+             active_neighbors = state.nid_to_neighbours[node_id] & state.active_node_ids
+             if len(active_neighbors) < state.threshold:
+                 log(ERROR, "Insufficient available nodes.")
+                 return False
+         return True
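Every stage ends with this liveness check: each sampled node must still have at least `threshold` active neighbours, otherwise too few key shares could be collected later to unmask. A minimal set-arithmetic sketch with made-up IDs:

    nid_to_neighbours = {1: {1, 2, 3}, 2: {1, 2, 3}, 3: {1, 2, 3}}  # hypothetical
    active_node_ids = {1, 3}  # node 2 dropped out
    threshold = 2

    ok = all(
        len(neighbours & active_node_ids) >= threshold
        for neighbours in nid_to_neighbours.values()
    )
    assert ok  # every node still has two active neighbours (itself included)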
+
+     def setup_stage(  # pylint: disable=R0912, R0914, R0915
+         self, driver: Driver, context: LegacyContext, state: WorkflowState
+     ) -> bool:
+         """Execute the 'setup' stage."""
+         # Obtain fit instructions
+         cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
+         current_round = cast(int, cfg[WorkflowKey.CURRENT_ROUND])
+         parameters = compat.parametersrecord_to_parameters(
+             context.state.parameters_records[MAIN_PARAMS_RECORD],
+             keep_input=True,
+         )
+         proxy_fitins_lst = context.strategy.configure_fit(
+             current_round, parameters, context.client_manager
+         )
+         if not proxy_fitins_lst:
+             log(INFO, "configure_fit: no clients selected, cancel")
+             return False
+         log(
+             INFO,
+             "configure_fit: strategy sampled %s clients (out of %s)",
+             len(proxy_fitins_lst),
+             context.client_manager.num_available(),
+         )
+
+         state.nid_to_fitins = {
+             proxy.node_id: compat.fitins_to_recordset(fitins, False)
+             for proxy, fitins in proxy_fitins_lst
+         }
+
+         # Protocol config
+         sampled_node_ids = list(state.nid_to_fitins.keys())
+         num_samples = len(sampled_node_ids)
+         if num_samples < 2:
+             log(ERROR, "The number of samples should be greater than 1.")
+             return False
+         if isinstance(self.num_shares, float):
+             state.num_shares = round(self.num_shares * num_samples)
+             # If even
+             if state.num_shares < num_samples and state.num_shares & 1 == 0:
+                 state.num_shares += 1
+             # If too small
+             if state.num_shares <= 2:
+                 state.num_shares = num_samples
+         else:
+             state.num_shares = self.num_shares
+         if isinstance(self.reconstruction_threshold, float):
+             state.threshold = round(self.reconstruction_threshold * state.num_shares)
+             # Avoid too small threshold
+             state.threshold = max(state.threshold, 2)
+         else:
+             state.threshold = self.reconstruction_threshold
+         state.active_node_ids = set(sampled_node_ids)
+         state.clipping_range = self.clipping_range
+         state.quantization_range = self.quantization_range
+         state.mod_range = self.modulus_range
+         state.max_weight = self.max_weight
+         sa_params_dict = {
+             Key.STAGE: Stage.SETUP,
+             Key.SAMPLE_NUMBER: num_samples,
+             Key.SHARE_NUMBER: state.num_shares,
+             Key.THRESHOLD: state.threshold,
+             Key.CLIPPING_RANGE: state.clipping_range,
+             Key.TARGET_RANGE: state.quantization_range,
+             Key.MOD_RANGE: state.mod_range,
+             Key.MAX_WEIGHT: state.max_weight,
+         }
+
+         # The number of shares should preferably be odd in the SecAgg+ protocol.
+         if num_samples != state.num_shares and state.num_shares & 1 == 0:
+             log(WARN, "Number of shares in the SecAgg+ protocol should be odd.")
+             state.num_shares += 1
+
+         # Shuffle node IDs
+         random.shuffle(sampled_node_ids)
+         # Build neighbour relations (node ID -> node IDs of neighbours)
+         half_share = state.num_shares >> 1
+         state.nid_to_neighbours = {
+             nid: {
+                 sampled_node_ids[(idx + offset) % num_samples]
+                 for offset in range(-half_share, half_share + 1)
+             }
+             for idx, nid in enumerate(sampled_node_ids)
+         }
+
+         state.sampled_node_ids = state.active_node_ids
+
+         # Send setup configuration to clients
+         cfgs_record = ConfigsRecord(sa_params_dict)  # type: ignore
+         content = RecordSet(configs_records={RECORD_KEY_CONFIGS: cfgs_record})
+
+         def make(nid: int) -> Message:
+             return driver.create_message(
+                 content=content,
+                 message_type=MessageType.TRAIN,
+                 dst_node_id=nid,
+                 group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
+                 ttl="",
+             )
+
+         log(
+             DEBUG,
+             "[Stage 0] Sending configurations to %s clients.",
+             len(state.active_node_ids),
+         )
+         msgs = driver.send_and_receive(
+             [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
+         )
+         state.active_node_ids = {
+             msg.metadata.src_node_id for msg in msgs if not msg.has_error()
+         }
+         log(
+             DEBUG,
+             "[Stage 0] Received public keys from %s clients.",
+             len(state.active_node_ids),
+         )
+
+         for msg in msgs:
+             if msg.has_error():
+                 continue
+             key_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
+             node_id = msg.metadata.src_node_id
+             pk1, pk2 = key_dict[Key.PUBLIC_KEY_1], key_dict[Key.PUBLIC_KEY_2]
+             state.nid_to_publickeys[node_id] = [cast(bytes, pk1), cast(bytes, pk2)]
+
+         return self._check_threshold(state)
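The neighbour relation built in this stage is a ring: after shuffling, each node is linked to the `num_shares >> 1` nodes on either side of itself (offset 0 is the node itself). A self-contained sketch of the same construction, with made-up node IDs:

    import random

    sampled_node_ids = [101, 102, 103, 104, 105]  # hypothetical node IDs
    num_samples = len(sampled_node_ids)
    num_shares = 3  # kept odd so each node has a symmetric neighbourhood
    half_share = num_shares >> 1

    random.shuffle(sampled_node_ids)
    nid_to_neighbours = {
        nid: {
            sampled_node_ids[(idx + offset) % num_samples]
            for offset in range(-half_share, half_share + 1)
        }
        for idx, nid in enumerate(sampled_node_ids)
    }
    # Every node ends up with exactly num_shares neighbours, itself included.
    assert all(len(nbrs) == num_shares for nbrs in nid_to_neighbours.values())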
+
+     def share_keys_stage(  # pylint: disable=R0914
+         self, driver: Driver, context: LegacyContext, state: WorkflowState
+     ) -> bool:
+         """Execute the 'share keys' stage."""
+         cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
+
+         def make(nid: int) -> Message:
+             neighbours = state.nid_to_neighbours[nid] & state.active_node_ids
+             cfgs_record = ConfigsRecord(
+                 {str(nid): state.nid_to_publickeys[nid] for nid in neighbours}
+             )
+             cfgs_record[Key.STAGE] = Stage.SHARE_KEYS
+             content = RecordSet(configs_records={RECORD_KEY_CONFIGS: cfgs_record})
+             return driver.create_message(
+                 content=content,
+                 message_type=MessageType.TRAIN,
+                 dst_node_id=nid,
+                 group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
+                 ttl="",
+             )
+
+         # Broadcast public keys to clients and receive secret key shares
+         log(
+             DEBUG,
+             "[Stage 1] Forwarding public keys to %s clients.",
+             len(state.active_node_ids),
+         )
+         msgs = driver.send_and_receive(
+             [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
+         )
+         state.active_node_ids = {
+             msg.metadata.src_node_id for msg in msgs if not msg.has_error()
+         }
+         log(
+             DEBUG,
+             "[Stage 1] Received encrypted key shares from %s clients.",
+             len(state.active_node_ids),
+         )
+
+         # Build forward packet list dictionary
+         srcs: List[int] = []
+         dsts: List[int] = []
+         ciphertexts: List[bytes] = []
+         fwd_ciphertexts: Dict[int, List[bytes]] = {
+             nid: [] for nid in state.active_node_ids
+         }  # dest node ID -> list of ciphertexts
+         fwd_srcs: Dict[int, List[int]] = {
+             nid: [] for nid in state.active_node_ids
+         }  # dest node ID -> list of src node IDs
+         for msg in msgs:
+             node_id = msg.metadata.src_node_id
+             res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
+             dst_lst = cast(List[int], res_dict[Key.DESTINATION_LIST])
+             ctxt_lst = cast(List[bytes], res_dict[Key.CIPHERTEXT_LIST])
+             srcs += [node_id] * len(dst_lst)
+             dsts += dst_lst
+             ciphertexts += ctxt_lst
+
+         for src, dst, ciphertext in zip(srcs, dsts, ciphertexts):
+             if dst in fwd_ciphertexts:
+                 fwd_ciphertexts[dst].append(ciphertext)
+                 fwd_srcs[dst].append(src)
+
+         state.forward_srcs = fwd_srcs
+         state.forward_ciphertexts = fwd_ciphertexts
+
+         return self._check_threshold(state)
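The second half of this stage is pure routing: each reported ciphertext is appended to the inbox of its destination node, and packets addressed to nodes that have since dropped out are silently discarded. A minimal sketch with hypothetical IDs and payloads:

    from typing import Dict, List

    # Hypothetical (src, dst, ciphertext) packets reported by clients
    packets = [(1, 2, b"c12"), (1, 3, b"c13"), (2, 1, b"c21"), (2, 4, b"c24")]
    active_node_ids = {1, 2, 3}  # node 4 dropped out

    fwd_ciphertexts: Dict[int, List[bytes]] = {nid: [] for nid in active_node_ids}
    fwd_srcs: Dict[int, List[int]] = {nid: [] for nid in active_node_ids}
    for src, dst, ciphertext in packets:
        if dst in fwd_ciphertexts:  # drop packets to inactive destinations
            fwd_ciphertexts[dst].append(ciphertext)
            fwd_srcs[dst].append(src)

    assert fwd_ciphertexts == {1: [b"c21"], 2: [b"c12"], 3: [b"c13"]}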
+
+     def collect_masked_vectors_stage(
+         self, driver: Driver, context: LegacyContext, state: WorkflowState
+     ) -> bool:
+         """Execute the 'collect masked vectors' stage."""
+         cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
+
+         # Send secret key shares to clients (plus FitIns) and collect masked vectors
+         def make(nid: int) -> Message:
+             cfgs_dict = {
+                 Key.STAGE: Stage.COLLECT_MASKED_VECTORS,
+                 Key.CIPHERTEXT_LIST: state.forward_ciphertexts[nid],
+                 Key.SOURCE_LIST: state.forward_srcs[nid],
+             }
+             cfgs_record = ConfigsRecord(cfgs_dict)  # type: ignore
+             content = state.nid_to_fitins[nid]
+             content.configs_records[RECORD_KEY_CONFIGS] = cfgs_record
+             return driver.create_message(
+                 content=content,
+                 message_type=MessageType.TRAIN,
+                 dst_node_id=nid,
+                 group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
+                 ttl="",
+             )
+
+         log(
+             DEBUG,
+             "[Stage 2] Forwarding encrypted key shares to %s clients.",
+             len(state.active_node_ids),
+         )
+         msgs = driver.send_and_receive(
+             [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
+         )
+         state.active_node_ids = {
+             msg.metadata.src_node_id for msg in msgs if not msg.has_error()
+         }
+         log(
+             DEBUG,
+             "[Stage 2] Received masked vectors from %s clients.",
+             len(state.active_node_ids),
+         )
+
+         # Clear cache
+         del state.forward_ciphertexts, state.forward_srcs, state.nid_to_fitins
+
+         # Sum collected masked vectors and compute active/dead node IDs
+         masked_vector = None
+         for msg in msgs:
+             res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
+             bytes_list = cast(List[bytes], res_dict[Key.MASKED_PARAMETERS])
+             client_masked_vec = [bytes_to_ndarray(b) for b in bytes_list]
+             if masked_vector is None:
+                 masked_vector = client_masked_vec
+             else:
+                 masked_vector = parameters_addition(masked_vector, client_masked_vec)
+         if masked_vector is not None:
+             masked_vector = parameters_mod(masked_vector, state.mod_range)
+             state.aggregate_ndarrays = masked_vector
+
+         return self._check_threshold(state)
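The aggregation in this stage is element-wise integer addition followed by a single modular reduction; because mask entries are sampled uniformly from [0, mod_range), reducing the sum modulo mod_range leaves the masked aggregate intact. A small numpy sketch under those assumptions, with made-up values:

    import numpy as np

    mod_range = 2**32
    # Two hypothetical masked client vectors, held as int64 to avoid overflow
    v1 = np.array([123_456_789, 4_294_967_295], dtype=np.int64)
    v2 = np.array([987_654_321, 17], dtype=np.int64)

    masked_sum = (v1 + v2) % mod_range
    print(masked_sum)  # the second entry wraps around to 16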
+
+     def unmask_stage(  # pylint: disable=R0912, R0914, R0915
+         self, driver: Driver, context: LegacyContext, state: WorkflowState
+     ) -> bool:
+         """Execute the 'unmask' stage."""
+         cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
+         current_round = cast(int, cfg[WorkflowKey.CURRENT_ROUND])
+
+         # Construct active node IDs and dead node IDs
+         active_nids = state.active_node_ids
+         dead_nids = state.sampled_node_ids - active_nids
+
+         # Send node IDs of active and dead clients and collect key shares from clients
+         def make(nid: int) -> Message:
+             neighbours = state.nid_to_neighbours[nid]
+             cfgs_dict = {
+                 Key.STAGE: Stage.UNMASK,
+                 Key.ACTIVE_NODE_ID_LIST: list(neighbours & active_nids),
+                 Key.DEAD_NODE_ID_LIST: list(neighbours & dead_nids),
+             }
+             cfgs_record = ConfigsRecord(cfgs_dict)  # type: ignore
+             content = RecordSet(configs_records={RECORD_KEY_CONFIGS: cfgs_record})
+             return driver.create_message(
+                 content=content,
+                 message_type=MessageType.TRAIN,
+                 dst_node_id=nid,
+                 group_id=str(current_round),
+                 ttl="",
+             )
+
+         log(
+             DEBUG,
+             "[Stage 3] Requesting key shares from %s clients to remove masks.",
+             len(state.active_node_ids),
+         )
+         msgs = driver.send_and_receive(
+             [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
+         )
+         state.active_node_ids = {
+             msg.metadata.src_node_id for msg in msgs if not msg.has_error()
+         }
+         log(
+             DEBUG,
+             "[Stage 3] Received key shares from %s clients.",
+             len(state.active_node_ids),
+         )
+
+         # Build collected shares dict
+         collected_shares_dict: Dict[int, List[bytes]] = {}
+         for nid in state.sampled_node_ids:
+             collected_shares_dict[nid] = []
+         for msg in msgs:
+             res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
+             nids = cast(List[int], res_dict[Key.NODE_ID_LIST])
+             shares = cast(List[bytes], res_dict[Key.SHARE_LIST])
+             for owner_nid, share in zip(nids, shares):
+                 collected_shares_dict[owner_nid].append(share)
+
+         # Remove masks for every client after the 'collect masked vectors' stage
+         masked_vector = state.aggregate_ndarrays
+         del state.aggregate_ndarrays
+         for nid, share_list in collected_shares_dict.items():
+             if len(share_list) < state.threshold:
+                 log(
+                     ERROR, "Not enough shares to recover secret in unmask vectors stage"
+                 )
+                 return False
+             secret = combine_shares(share_list)
+             if nid in active_nids:
+                 # The seed for PRG is the private mask seed of an active client.
+                 private_mask = pseudo_rand_gen(
+                     secret, state.mod_range, get_parameters_shape(masked_vector)
+                 )
+                 masked_vector = parameters_subtraction(masked_vector, private_mask)
+             else:
+                 # The seed for PRG is the secret key 1 of a dropped client.
+                 neighbours = state.nid_to_neighbours[nid]
+                 neighbours.remove(nid)
+
+                 for neighbor_nid in neighbours:
+                     shared_key = generate_shared_key(
+                         bytes_to_private_key(secret),
+                         bytes_to_public_key(state.nid_to_publickeys[neighbor_nid][0]),
+                     )
+                     pairwise_mask = pseudo_rand_gen(
+                         shared_key, state.mod_range, get_parameters_shape(masked_vector)
+                     )
+                     if nid > neighbor_nid:
+                         masked_vector = parameters_addition(
+                             masked_vector, pairwise_mask
+                         )
+                     else:
+                         masked_vector = parameters_subtraction(
+                             masked_vector, pairwise_mask
+                         )
+         recon_parameters = parameters_mod(masked_vector, state.mod_range)
+         q_total_ratio, recon_parameters = factor_extract(recon_parameters)
+         inv_dq_total_ratio = state.quantization_range / q_total_ratio
+         # recon_parameters = parameters_divide(recon_parameters, total_weights_factor)
+         aggregated_vector = dequantize(
+             recon_parameters,
+             state.clipping_range,
+             state.quantization_range,
+         )
+         offset = -(len(active_nids) - 1) * state.clipping_range
+         for vec in aggregated_vector:
+             vec += offset
+             vec *= inv_dq_total_ratio
+         state.aggregate_ndarrays = aggregated_vector
+
+         # No exception/failure handling currently
+         log(
+             INFO,
+             "aggregate_fit: received %s results and %s failures",
+             1,
+             0,
+         )
+
+         final_fitres = FitRes(
+             status=Status(code=Code.OK, message=""),
+             parameters=ndarrays_to_parameters(aggregated_vector),
+             num_examples=round(state.max_weight / inv_dq_total_ratio),
+             metrics={},
+         )
+         empty_proxy = DriverClientProxy(
+             0,
+             driver.grpc_driver,  # type: ignore
+             False,
+             driver.run_id,  # type: ignore
+         )
+         aggregated_result = context.strategy.aggregate_fit(
+             current_round, [(empty_proxy, final_fitres)], []
+         )
+         parameters_aggregated, metrics_aggregated = aggregated_result
+
+         # Update the parameters and write history
+         if parameters_aggregated:
+             paramsrecord = compat.parameters_to_parametersrecord(
+                 parameters_aggregated, True
+             )
+             context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
+         context.history.add_metrics_distributed_fit(
+             server_round=current_round, metrics=metrics_aggregated
+         )
+         return True
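The add-or-subtract branch on `nid > neighbor_nid` mirrors the client-side sign convention for pairwise masks: both nodes of a pair derive the same shared PRG seed, the node with the larger ID adds the mask and the other subtracts it, so for pairs of surviving clients the masks cancel in the sum without any reconstruction. A toy illustration of that cancellation, with made-up update vectors and mask:

    import numpy as np

    mod_range = 2**32
    x1 = np.array([5, 7], dtype=np.int64)   # hypothetical quantized updates
    x2 = np.array([11, 2], dtype=np.int64)
    pairwise_mask = np.array([3_000_000_000, 42], dtype=np.int64)  # shared PRG output

    # The client with the larger node ID adds the mask; the other subtracts it.
    masked_1 = (x1 + pairwise_mask) % mod_range
    masked_2 = (x2 - pairwise_mask) % mod_range

    # The masks cancel in the modular sum, so live pairs need no unmasking.
    assert np.array_equal((masked_1 + masked_2) % mod_range, (x1 + x2) % mod_range)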
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: flwr-nightly
- Version: 1.8.0.dev20240309
+ Version: 1.8.0.dev20240311
  Summary: Flower: A Friendly Federated Learning Framework
  Home-page: https://flower.ai
  License: Apache-2.0