flwr-nightly 1.8.0.dev20240310__py3-none-any.whl → 1.8.0.dev20240312__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -15,8 +15,9 @@
15
15
  """Legacy default workflows."""
16
16
 
17
17
 
18
+ import io
18
19
  import timeit
19
- from logging import DEBUG, INFO
20
+ from logging import INFO
20
21
  from typing import Optional, cast
21
22
 
22
23
  import flwr.common.recordset_compat as compat
@@ -58,16 +59,18 @@ class DefaultWorkflow:
58
59
  )
59
60
 
60
61
  # Initialize parameters
62
+ log(INFO, "[INIT]")
61
63
  default_init_params_workflow(driver, context)
62
64
 
63
65
  # Run federated learning for num_rounds
64
- log(INFO, "FL starting")
65
66
  start_time = timeit.default_timer()
66
67
  cfg = ConfigsRecord()
67
68
  cfg[Key.START_TIME] = start_time
68
69
  context.state.configs_records[MAIN_CONFIGS_RECORD] = cfg
69
70
 
70
71
  for current_round in range(1, context.config.num_rounds + 1):
72
+ log(INFO, "")
73
+ log(INFO, "[ROUND %s]", current_round)
71
74
  cfg[Key.CURRENT_ROUND] = current_round
72
75
 
73
76
  # Fit round
@@ -79,22 +82,19 @@ class DefaultWorkflow:
79
82
  # Evaluate round
80
83
  self.evaluate_workflow(driver, context)
81
84
 
82
- # Bookkeeping
85
+ # Bookkeeping and log results
83
86
  end_time = timeit.default_timer()
84
87
  elapsed = end_time - start_time
85
- log(INFO, "FL finished in %s", elapsed)
86
-
87
- # Log results
88
88
  hist = context.history
89
- log(INFO, "app_fit: losses_distributed %s", str(hist.losses_distributed))
90
- log(
91
- INFO,
92
- "app_fit: metrics_distributed_fit %s",
93
- str(hist.metrics_distributed_fit),
94
- )
95
- log(INFO, "app_fit: metrics_distributed %s", str(hist.metrics_distributed))
96
- log(INFO, "app_fit: losses_centralized %s", str(hist.losses_centralized))
97
- log(INFO, "app_fit: metrics_centralized %s", str(hist.metrics_centralized))
89
+ log(INFO, "")
90
+ log(INFO, "[SUMMARY]")
91
+ log(INFO, "Run finished %s rounds in %.2fs", context.config.num_rounds, elapsed)
92
+ for idx, line in enumerate(io.StringIO(str(hist))):
93
+ if idx == 0:
94
+ log(INFO, "%s", line.strip("\n"))
95
+ else:
96
+ log(INFO, "\t%s", line.strip("\n"))
97
+ log(INFO, "")
98
98
 
99
99
  # Terminate the thread
100
100
  f_stop.set()
@@ -107,12 +107,11 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
107
107
  if not isinstance(context, LegacyContext):
108
108
  raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
109
109
 
110
- log(INFO, "Initializing global parameters")
111
110
  parameters = context.strategy.initialize_parameters(
112
111
  client_manager=context.client_manager
113
112
  )
114
113
  if parameters is not None:
115
- log(INFO, "Using initial parameters provided by strategy")
114
+ log(INFO, "Using initial global parameters provided by strategy")
116
115
  paramsrecord = compat.parameters_to_parametersrecord(
117
116
  parameters, keep_input=True
118
117
  )
@@ -128,7 +127,7 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
128
127
  content=content,
129
128
  message_type=MessageTypeLegacy.GET_PARAMETERS,
130
129
  dst_node_id=random_client.node_id,
131
- group_id="",
130
+ group_id="0",
132
131
  ttl="",
133
132
  )
134
133
  ]
@@ -140,7 +139,7 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
140
139
  context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
141
140
 
142
141
  # Evaluate initial parameters
143
- log(INFO, "Evaluating initial parameters")
142
+ log(INFO, "Evaluating initial global parameters")
144
143
  parameters = compat.parametersrecord_to_parameters(paramsrecord, keep_input=True)
145
144
  res = context.strategy.evaluate(0, parameters=parameters)
146
145
  if res is not None:
@@ -186,7 +185,9 @@ def default_centralized_evaluation_workflow(_: Driver, context: Context) -> None
186
185
  )
187
186
 
188
187
 
189
- def default_fit_workflow(driver: Driver, context: Context) -> None:
188
+ def default_fit_workflow( # pylint: disable=R0914
189
+ driver: Driver, context: Context
190
+ ) -> None:
190
191
  """Execute the default workflow for a single fit round."""
191
192
  if not isinstance(context, LegacyContext):
192
193
  raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
@@ -207,12 +208,11 @@ def default_fit_workflow(driver: Driver, context: Context) -> None:
207
208
  )
208
209
 
209
210
  if not client_instructions:
210
- log(INFO, "fit_round %s: no clients selected, cancel", current_round)
211
+ log(INFO, "configure_fit: no clients selected, cancel")
211
212
  return
212
213
  log(
213
- DEBUG,
214
- "fit_round %s: strategy sampled %s clients (out of %s)",
215
- current_round,
214
+ INFO,
215
+ "configure_fit: strategy sampled %s clients (out of %s)",
216
216
  len(client_instructions),
217
217
  context.client_manager.num_available(),
218
218
  )
@@ -226,7 +226,7 @@ def default_fit_workflow(driver: Driver, context: Context) -> None:
226
226
  content=compat.fitins_to_recordset(fitins, True),
227
227
  message_type=MessageType.TRAIN,
228
228
  dst_node_id=proxy.node_id,
229
- group_id="",
229
+ group_id=str(current_round),
230
230
  ttl="",
231
231
  )
232
232
  for proxy, fitins in client_instructions
@@ -236,14 +236,14 @@ def default_fit_workflow(driver: Driver, context: Context) -> None:
236
236
  # collect `fit` results from all clients participating in this round
237
237
  messages = list(driver.send_and_receive(out_messages))
238
238
  del out_messages
239
+ num_failures = len([msg for msg in messages if msg.has_error()])
239
240
 
240
241
  # No exception/failure handling currently
241
242
  log(
242
- DEBUG,
243
- "fit_round %s received %s results and %s failures",
244
- current_round,
245
- len(messages),
246
- 0,
243
+ INFO,
244
+ "aggregate_fit: received %s results and %s failures",
245
+ len(messages) - num_failures,
246
+ num_failures,
247
247
  )
248
248
 
249
249
  # Aggregate training results
@@ -288,12 +288,11 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
288
288
  client_manager=context.client_manager,
289
289
  )
290
290
  if not client_instructions:
291
- log(INFO, "evaluate_round %s: no clients selected, cancel", current_round)
291
+ log(INFO, "configure_evaluate: no clients selected, skipping evaluation")
292
292
  return
293
293
  log(
294
- DEBUG,
295
- "evaluate_round %s: strategy sampled %s clients (out of %s)",
296
- current_round,
294
+ INFO,
295
+ "configure_evaluate: strategy sampled %s clients (out of %s)",
297
296
  len(client_instructions),
298
297
  context.client_manager.num_available(),
299
298
  )
@@ -307,7 +306,7 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
307
306
  content=compat.evaluateins_to_recordset(evalins, True),
308
307
  message_type=MessageType.EVALUATE,
309
308
  dst_node_id=proxy.node_id,
310
- group_id="",
309
+ group_id=str(current_round),
311
310
  ttl="",
312
311
  )
313
312
  for proxy, evalins in client_instructions
@@ -317,14 +316,14 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
317
316
  # collect `evaluate` results from all clients participating in this round
318
317
  messages = list(driver.send_and_receive(out_messages))
319
318
  del out_messages
319
+ num_failures = len([msg for msg in messages if msg.has_error()])
320
320
 
321
321
  # No exception/failure handling currently
322
322
  log(
323
- DEBUG,
324
- "evaluate_round %s received %s results and %s failures",
325
- current_round,
326
- len(messages),
327
- 0,
323
+ INFO,
324
+ "aggregate_evaluate: received %s results and %s failures",
325
+ len(messages) - num_failures,
326
+ num_failures,
328
327
  )
329
328
 
330
329
  # Aggregate the evaluation results
@@ -15,8 +15,10 @@
15
15
  """Secure Aggregation workflows."""
16
16
 
17
17
 
18
+ from .secagg_workflow import SecAggWorkflow
18
19
  from .secaggplus_workflow import SecAggPlusWorkflow
19
20
 
20
21
  __all__ = [
21
22
  "SecAggPlusWorkflow",
23
+ "SecAggWorkflow",
22
24
  ]
@@ -0,0 +1,112 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Workflow for the SecAgg protocol."""
16
+
17
+
18
+ from typing import Optional, Union
19
+
20
+ from .secaggplus_workflow import SecAggPlusWorkflow
21
+
22
+
23
+ class SecAggWorkflow(SecAggPlusWorkflow):
24
+ """The workflow for the SecAgg protocol.
25
+
26
+ The SecAgg protocol ensures the secure summation of integer vectors owned by
27
+ multiple parties, without accessing any individual integer vector. This workflow
28
+ allows the server to compute the weighted average of model parameters across all
29
+ clients, ensuring individual contributions remain private. This is achieved by
30
+ clients sending both, a weighting factor and a weighted version of the locally
31
+ updated parameters, both of which are masked for privacy. Specifically, each
32
+ client uploads "[w, w * params]" with masks, where weighting factor 'w' is the
33
+ number of examples ('num_examples') and 'params' represents the model parameters
34
+ ('parameters') from the client's `FitRes`. The server then aggregates these
35
+ contributions to compute the weighted average of model parameters.
36
+
37
+ The protocol involves four main stages:
38
+ - 'setup': Send SecAgg configuration to clients and collect their public keys.
39
+ - 'share keys': Broadcast public keys among clients and collect encrypted secret
40
+ key shares.
41
+ - 'collect masked vectors': Forward encrypted secret key shares to target clients
42
+ and collect masked model parameters.
43
+ - 'unmask': Collect secret key shares to decrypt and aggregate the model parameters.
44
+
45
+ Only the aggregated model parameters are exposed and passed to
46
+ `Strategy.aggregate_fit`, ensuring individual data privacy.
47
+
48
+ Parameters
49
+ ----------
50
+ reconstruction_threshold : Union[int, float]
51
+ The minimum number of shares required to reconstruct a client's private key,
52
+ or, if specified as a float, it represents the proportion of the total number
53
+ of shares needed for reconstruction. This threshold ensures privacy by allowing
54
+ for the recovery of contributions from dropped clients during aggregation,
55
+ without compromising individual client data.
56
+ max_weight : Optional[float] (default: 1000.0)
57
+ The maximum value of the weight that can be assigned to any single client's
58
+ update during the weighted average calculation on the server side, e.g., in the
59
+ FedAvg algorithm.
60
+ clipping_range : float, optional (default: 8.0)
61
+ The range within which model parameters are clipped before quantization.
62
+ This parameter ensures each model parameter is bounded within
63
+ [-clipping_range, clipping_range], facilitating quantization.
64
+ quantization_range : int, optional (default: 4194304, this equals 2**22)
65
+ The size of the range into which floating-point model parameters are quantized,
66
+ mapping each parameter to an integer in [0, quantization_range-1]. This
67
+ facilitates cryptographic operations on the model updates.
68
+ modulus_range : int, optional (default: 4294967296, this equals 2**32)
69
+ The range of values from which random mask entries are uniformly sampled
70
+ ([0, modulus_range-1]). `modulus_range` must be less than 4294967296.
71
+ Please use 2**n values for `modulus_range` to prevent overflow issues.
72
+ timeout : Optional[float] (default: None)
73
+ The timeout duration in seconds. If specified, the workflow will wait for
74
+ replies for this duration each time. If `None`, there is no time limit and
75
+ the workflow will wait until replies for all messages are received.
76
+
77
+ Notes
78
+ -----
79
+ - Each client's private key is split into N shares under the SecAgg protocol, where
80
+ N is the number of selected clients.
81
+ - Generally, higher `reconstruction_threshold` means better privacy guarantees but
82
+ less tolerance to dropouts.
83
+ - Too large `max_weight` may compromise the precision of the quantization.
84
+ - `modulus_range` must be 2**n and larger than `quantization_range`.
85
+ - When `reconstruction_threshold` is a float, it is interpreted as the proportion of
86
+ the number of all selected clients needed for the reconstruction of a private key.
87
+ This feature enables flexibility in setting the security threshold relative to the
88
+ number of selected clients.
89
+ - `reconstruction_threshold`, and the quantization parameters
90
+ (`clipping_range`, `quantization_range`, `modulus_range`) play critical roles in
91
+ balancing privacy, robustness, and efficiency within the SecAgg protocol.
92
+ """
93
+
94
+ def __init__( # pylint: disable=R0913
95
+ self,
96
+ reconstruction_threshold: Union[int, float],
97
+ *,
98
+ max_weight: float = 1000.0,
99
+ clipping_range: float = 8.0,
100
+ quantization_range: int = 4194304,
101
+ modulus_range: int = 4294967296,
102
+ timeout: Optional[float] = None,
103
+ ) -> None:
104
+ super().__init__(
105
+ num_shares=1.0,
106
+ reconstruction_threshold=reconstruction_threshold,
107
+ max_weight=max_weight,
108
+ clipping_range=clipping_range,
109
+ quantization_range=quantization_range,
110
+ modulus_range=modulus_range,
111
+ timeout=timeout,
112
+ )
@@ -17,12 +17,11 @@
17
17
 
18
18
  import random
19
19
  from dataclasses import dataclass, field
20
- from logging import ERROR, WARN
21
- from typing import Dict, List, Optional, Set, Union, cast
20
+ from logging import DEBUG, ERROR, INFO, WARN
21
+ from typing import Dict, List, Optional, Set, Tuple, Union, cast
22
22
 
23
23
  import flwr.common.recordset_compat as compat
24
24
  from flwr.common import (
25
- Code,
26
25
  ConfigsRecord,
27
26
  Context,
28
27
  FitRes,
@@ -30,7 +29,6 @@ from flwr.common import (
30
29
  MessageType,
31
30
  NDArrays,
32
31
  RecordSet,
33
- Status,
34
32
  bytes_to_ndarray,
35
33
  log,
36
34
  ndarrays_to_parameters,
@@ -55,7 +53,7 @@ from flwr.common.secure_aggregation.secaggplus_constants import (
55
53
  Stage,
56
54
  )
57
55
  from flwr.common.secure_aggregation.secaggplus_utils import pseudo_rand_gen
58
- from flwr.server.compat.driver_client_proxy import DriverClientProxy
56
+ from flwr.server.client_proxy import ClientProxy
59
57
  from flwr.server.compat.legacy_context import LegacyContext
60
58
  from flwr.server.driver import Driver
61
59
 
@@ -67,6 +65,7 @@ from ..constant import Key as WorkflowKey
67
65
  class WorkflowState: # pylint: disable=R0902
68
66
  """The state of the SecAgg+ protocol."""
69
67
 
68
+ nid_to_proxies: Dict[int, ClientProxy] = field(default_factory=dict)
70
69
  nid_to_fitins: Dict[int, RecordSet] = field(default_factory=dict)
71
70
  sampled_node_ids: Set[int] = field(default_factory=set)
72
71
  active_node_ids: Set[int] = field(default_factory=set)
@@ -81,6 +80,7 @@ class WorkflowState: # pylint: disable=R0902
81
80
  forward_srcs: Dict[int, List[int]] = field(default_factory=dict)
82
81
  forward_ciphertexts: Dict[int, List[bytes]] = field(default_factory=dict)
83
82
  aggregate_ndarrays: NDArrays = field(default_factory=list)
83
+ legacy_results: List[Tuple[ClientProxy, FitRes]] = field(default_factory=list)
84
84
 
85
85
 
86
86
  class SecAggPlusWorkflow:
@@ -101,7 +101,7 @@ class SecAggPlusWorkflow:
101
101
  - 'setup': Send SecAgg+ configuration to clients and collect their public keys.
102
102
  - 'share keys': Broadcast public keys among clients and collect encrypted secret
103
103
  key shares.
104
- - 'collect masked inputs': Forward encrypted secret key shares to target clients
104
+ - 'collect masked vectors': Forward encrypted secret key shares to target clients
105
105
  and collect masked model parameters.
106
106
  - 'unmask': Collect secret key shares to decrypt and aggregate the model parameters.
107
107
 
@@ -195,12 +195,15 @@ class SecAggPlusWorkflow:
195
195
  steps = (
196
196
  self.setup_stage,
197
197
  self.share_keys_stage,
198
- self.collect_masked_input_stage,
198
+ self.collect_masked_vectors_stage,
199
199
  self.unmask_stage,
200
200
  )
201
+ log(INFO, "Secure aggregation commencing.")
201
202
  for step in steps:
202
203
  if not step(driver, context, state):
204
+ log(INFO, "Secure aggregation halted.")
203
205
  return
206
+ log(INFO, "Secure aggregation completed.")
204
207
 
205
208
  def _check_init_params(self) -> None: # pylint: disable=R0912
206
209
  # Check `num_shares`
@@ -287,10 +290,21 @@ class SecAggPlusWorkflow:
287
290
  proxy_fitins_lst = context.strategy.configure_fit(
288
291
  current_round, parameters, context.client_manager
289
292
  )
293
+ if not proxy_fitins_lst:
294
+ log(INFO, "configure_fit: no clients selected, cancel")
295
+ return False
296
+ log(
297
+ INFO,
298
+ "configure_fit: strategy sampled %s clients (out of %s)",
299
+ len(proxy_fitins_lst),
300
+ context.client_manager.num_available(),
301
+ )
302
+
290
303
  state.nid_to_fitins = {
291
- proxy.node_id: compat.fitins_to_recordset(fitins, False)
304
+ proxy.node_id: compat.fitins_to_recordset(fitins, True)
292
305
  for proxy, fitins in proxy_fitins_lst
293
306
  }
307
+ state.nid_to_proxies = {proxy.node_id: proxy for proxy, _ in proxy_fitins_lst}
294
308
 
295
309
  # Protocol config
296
310
  sampled_node_ids = list(state.nid_to_fitins.keys())
@@ -362,12 +376,22 @@ class SecAggPlusWorkflow:
362
376
  ttl="",
363
377
  )
364
378
 
379
+ log(
380
+ DEBUG,
381
+ "[Stage 0] Sending configurations to %s clients.",
382
+ len(state.active_node_ids),
383
+ )
365
384
  msgs = driver.send_and_receive(
366
385
  [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
367
386
  )
368
387
  state.active_node_ids = {
369
388
  msg.metadata.src_node_id for msg in msgs if not msg.has_error()
370
389
  }
390
+ log(
391
+ DEBUG,
392
+ "[Stage 0] Received public keys from %s clients.",
393
+ len(state.active_node_ids),
394
+ )
371
395
 
372
396
  for msg in msgs:
373
397
  if msg.has_error():
@@ -401,12 +425,22 @@ class SecAggPlusWorkflow:
401
425
  )
402
426
 
403
427
  # Broadcast public keys to clients and receive secret key shares
428
+ log(
429
+ DEBUG,
430
+ "[Stage 1] Forwarding public keys to %s clients.",
431
+ len(state.active_node_ids),
432
+ )
404
433
  msgs = driver.send_and_receive(
405
434
  [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
406
435
  )
407
436
  state.active_node_ids = {
408
437
  msg.metadata.src_node_id for msg in msgs if not msg.has_error()
409
438
  }
439
+ log(
440
+ DEBUG,
441
+ "[Stage 1] Received encrypted key shares from %s clients.",
442
+ len(state.active_node_ids),
443
+ )
410
444
 
411
445
  # Build forward packet list dictionary
412
446
  srcs: List[int] = []
@@ -437,16 +471,16 @@ class SecAggPlusWorkflow:
437
471
 
438
472
  return self._check_threshold(state)
439
473
 
440
- def collect_masked_input_stage(
474
+ def collect_masked_vectors_stage(
441
475
  self, driver: Driver, context: LegacyContext, state: WorkflowState
442
476
  ) -> bool:
443
- """Execute the 'collect masked input' stage."""
477
+ """Execute the 'collect masked vectors' stage."""
444
478
  cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
445
479
 
446
- # Send secret key shares to clients (plus FitIns) and collect masked input
480
+ # Send secret key shares to clients (plus FitIns) and collect masked vectors
447
481
  def make(nid: int) -> Message:
448
482
  cfgs_dict = {
449
- Key.STAGE: Stage.COLLECT_MASKED_INPUT,
483
+ Key.STAGE: Stage.COLLECT_MASKED_VECTORS,
450
484
  Key.CIPHERTEXT_LIST: state.forward_ciphertexts[nid],
451
485
  Key.SOURCE_LIST: state.forward_srcs[nid],
452
486
  }
@@ -461,12 +495,22 @@ class SecAggPlusWorkflow:
461
495
  ttl="",
462
496
  )
463
497
 
498
+ log(
499
+ DEBUG,
500
+ "[Stage 2] Forwarding encrypted key shares to %s clients.",
501
+ len(state.active_node_ids),
502
+ )
464
503
  msgs = driver.send_and_receive(
465
504
  [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
466
505
  )
467
506
  state.active_node_ids = {
468
507
  msg.metadata.src_node_id for msg in msgs if not msg.has_error()
469
508
  }
509
+ log(
510
+ DEBUG,
511
+ "[Stage 2] Received masked vectors from %s clients.",
512
+ len(state.active_node_ids),
513
+ )
470
514
 
471
515
  # Clear cache
472
516
  del state.forward_ciphertexts, state.forward_srcs, state.nid_to_fitins
@@ -485,9 +529,15 @@ class SecAggPlusWorkflow:
485
529
  masked_vector = parameters_mod(masked_vector, state.mod_range)
486
530
  state.aggregate_ndarrays = masked_vector
487
531
 
532
+ # Backward compatibility with Strategy
533
+ for msg in msgs:
534
+ fitres = compat.recordset_to_fitres(msg.content, True)
535
+ proxy = state.nid_to_proxies[msg.metadata.src_node_id]
536
+ state.legacy_results.append((proxy, fitres))
537
+
488
538
  return self._check_threshold(state)
489
539
 
490
- def unmask_stage( # pylint: disable=R0912, R0914
540
+ def unmask_stage( # pylint: disable=R0912, R0914, R0915
491
541
  self, driver: Driver, context: LegacyContext, state: WorkflowState
492
542
  ) -> bool:
493
543
  """Execute the 'unmask' stage."""
@@ -516,12 +566,22 @@ class SecAggPlusWorkflow:
516
566
  ttl="",
517
567
  )
518
568
 
569
+ log(
570
+ DEBUG,
571
+ "[Stage 3] Requesting key shares from %s clients to remove masks.",
572
+ len(state.active_node_ids),
573
+ )
519
574
  msgs = driver.send_and_receive(
520
575
  [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
521
576
  )
522
577
  state.active_node_ids = {
523
578
  msg.metadata.src_node_id for msg in msgs if not msg.has_error()
524
579
  }
580
+ log(
581
+ DEBUG,
582
+ "[Stage 3] Received key shares from %s clients.",
583
+ len(state.active_node_ids),
584
+ )
525
585
 
526
586
  # Build collected shares dict
527
587
  collected_shares_dict: Dict[int, List[bytes]] = {}
@@ -534,7 +594,7 @@ class SecAggPlusWorkflow:
534
594
  for owner_nid, share in zip(nids, shares):
535
595
  collected_shares_dict[owner_nid].append(share)
536
596
 
537
- # Remove mask for every client who is available after collect_masked_input stage
597
+ # Remove masks for every active client after collect_masked_vectors stage
538
598
  masked_vector = state.aggregate_ndarrays
539
599
  del state.aggregate_ndarrays
540
600
  for nid, share_list in collected_shares_dict.items():
@@ -584,18 +644,30 @@ class SecAggPlusWorkflow:
584
644
  for vec in aggregated_vector:
585
645
  vec += offset
586
646
  vec *= inv_dq_total_ratio
587
- state.aggregate_ndarrays = aggregated_vector
588
- final_fitres = FitRes(
589
- status=Status(code=Code.OK, message=""),
590
- parameters=ndarrays_to_parameters(aggregated_vector),
591
- num_examples=round(state.max_weight / inv_dq_total_ratio),
592
- metrics={},
593
- )
594
- empty_proxy = DriverClientProxy(
647
+
648
+ # Backward compatibility with Strategy
649
+ results = state.legacy_results
650
+ parameters = ndarrays_to_parameters(aggregated_vector)
651
+ for _, fitres in results:
652
+ fitres.parameters = parameters
653
+
654
+ # No exception/failure handling currently
655
+ log(
656
+ INFO,
657
+ "aggregate_fit: received %s results and %s failures",
658
+ len(results),
595
659
  0,
596
- driver.grpc_driver, # type: ignore
597
- False,
598
- driver.run_id, # type: ignore
599
660
  )
600
- context.strategy.aggregate_fit(current_round, [(empty_proxy, final_fitres)], [])
661
+ aggregated_result = context.strategy.aggregate_fit(current_round, results, [])
662
+ parameters_aggregated, metrics_aggregated = aggregated_result
663
+
664
+ # Update the parameters and write history
665
+ if parameters_aggregated:
666
+ paramsrecord = compat.parameters_to_parametersrecord(
667
+ parameters_aggregated, True
668
+ )
669
+ context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
670
+ context.history.add_metrics_distributed_fit(
671
+ server_round=current_round, metrics=metrics_aggregated
672
+ )
601
673
  return True
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: flwr-nightly
3
- Version: 1.8.0.dev20240310
3
+ Version: 1.8.0.dev20240312
4
4
  Summary: Flower: A Friendly Federated Learning Framework
5
5
  Home-page: https://flower.ai
6
6
  License: Apache-2.0