flwr-nightly 1.8.0.dev20240310__py3-none-any.whl → 1.8.0.dev20240312__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -15,8 +15,9 @@
15
15
  """Legacy default workflows."""
16
16
 
17
17
 
18
+ import io
18
19
  import timeit
19
- from logging import DEBUG, INFO
20
+ from logging import INFO
20
21
  from typing import Optional, cast
21
22
 
22
23
  import flwr.common.recordset_compat as compat
@@ -58,16 +59,18 @@ class DefaultWorkflow:
58
59
  )
59
60
 
60
61
  # Initialize parameters
62
+ log(INFO, "[INIT]")
61
63
  default_init_params_workflow(driver, context)
62
64
 
63
65
  # Run federated learning for num_rounds
64
- log(INFO, "FL starting")
65
66
  start_time = timeit.default_timer()
66
67
  cfg = ConfigsRecord()
67
68
  cfg[Key.START_TIME] = start_time
68
69
  context.state.configs_records[MAIN_CONFIGS_RECORD] = cfg
69
70
 
70
71
  for current_round in range(1, context.config.num_rounds + 1):
72
+ log(INFO, "")
73
+ log(INFO, "[ROUND %s]", current_round)
71
74
  cfg[Key.CURRENT_ROUND] = current_round
72
75
 
73
76
  # Fit round
@@ -79,22 +82,19 @@ class DefaultWorkflow:
79
82
  # Evaluate round
80
83
  self.evaluate_workflow(driver, context)
81
84
 
82
- # Bookkeeping
85
+ # Bookkeeping and log results
83
86
  end_time = timeit.default_timer()
84
87
  elapsed = end_time - start_time
85
- log(INFO, "FL finished in %s", elapsed)
86
-
87
- # Log results
88
88
  hist = context.history
89
- log(INFO, "app_fit: losses_distributed %s", str(hist.losses_distributed))
90
- log(
91
- INFO,
92
- "app_fit: metrics_distributed_fit %s",
93
- str(hist.metrics_distributed_fit),
94
- )
95
- log(INFO, "app_fit: metrics_distributed %s", str(hist.metrics_distributed))
96
- log(INFO, "app_fit: losses_centralized %s", str(hist.losses_centralized))
97
- log(INFO, "app_fit: metrics_centralized %s", str(hist.metrics_centralized))
89
+ log(INFO, "")
90
+ log(INFO, "[SUMMARY]")
91
+ log(INFO, "Run finished %s rounds in %.2fs", context.config.num_rounds, elapsed)
92
+ for idx, line in enumerate(io.StringIO(str(hist))):
93
+ if idx == 0:
94
+ log(INFO, "%s", line.strip("\n"))
95
+ else:
96
+ log(INFO, "\t%s", line.strip("\n"))
97
+ log(INFO, "")
98
98
 
99
99
  # Terminate the thread
100
100
  f_stop.set()
@@ -107,12 +107,11 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
107
107
  if not isinstance(context, LegacyContext):
108
108
  raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
109
109
 
110
- log(INFO, "Initializing global parameters")
111
110
  parameters = context.strategy.initialize_parameters(
112
111
  client_manager=context.client_manager
113
112
  )
114
113
  if parameters is not None:
115
- log(INFO, "Using initial parameters provided by strategy")
114
+ log(INFO, "Using initial global parameters provided by strategy")
116
115
  paramsrecord = compat.parameters_to_parametersrecord(
117
116
  parameters, keep_input=True
118
117
  )
@@ -128,7 +127,7 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
128
127
  content=content,
129
128
  message_type=MessageTypeLegacy.GET_PARAMETERS,
130
129
  dst_node_id=random_client.node_id,
131
- group_id="",
130
+ group_id="0",
132
131
  ttl="",
133
132
  )
134
133
  ]
@@ -140,7 +139,7 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
140
139
  context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
141
140
 
142
141
  # Evaluate initial parameters
143
- log(INFO, "Evaluating initial parameters")
142
+ log(INFO, "Evaluating initial global parameters")
144
143
  parameters = compat.parametersrecord_to_parameters(paramsrecord, keep_input=True)
145
144
  res = context.strategy.evaluate(0, parameters=parameters)
146
145
  if res is not None:
@@ -186,7 +185,9 @@ def default_centralized_evaluation_workflow(_: Driver, context: Context) -> None
186
185
  )
187
186
 
188
187
 
189
- def default_fit_workflow(driver: Driver, context: Context) -> None:
188
+ def default_fit_workflow( # pylint: disable=R0914
189
+ driver: Driver, context: Context
190
+ ) -> None:
190
191
  """Execute the default workflow for a single fit round."""
191
192
  if not isinstance(context, LegacyContext):
192
193
  raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
@@ -207,12 +208,11 @@ def default_fit_workflow(driver: Driver, context: Context) -> None:
207
208
  )
208
209
 
209
210
  if not client_instructions:
210
- log(INFO, "fit_round %s: no clients selected, cancel", current_round)
211
+ log(INFO, "configure_fit: no clients selected, cancel")
211
212
  return
212
213
  log(
213
- DEBUG,
214
- "fit_round %s: strategy sampled %s clients (out of %s)",
215
- current_round,
214
+ INFO,
215
+ "configure_fit: strategy sampled %s clients (out of %s)",
216
216
  len(client_instructions),
217
217
  context.client_manager.num_available(),
218
218
  )
@@ -226,7 +226,7 @@ def default_fit_workflow(driver: Driver, context: Context) -> None:
226
226
  content=compat.fitins_to_recordset(fitins, True),
227
227
  message_type=MessageType.TRAIN,
228
228
  dst_node_id=proxy.node_id,
229
- group_id="",
229
+ group_id=str(current_round),
230
230
  ttl="",
231
231
  )
232
232
  for proxy, fitins in client_instructions
@@ -236,14 +236,14 @@ def default_fit_workflow(driver: Driver, context: Context) -> None:
236
236
  # collect `fit` results from all clients participating in this round
237
237
  messages = list(driver.send_and_receive(out_messages))
238
238
  del out_messages
239
+ num_failures = len([msg for msg in messages if msg.has_error()])
239
240
 
240
241
  # No exception/failure handling currently
241
242
  log(
242
- DEBUG,
243
- "fit_round %s received %s results and %s failures",
244
- current_round,
245
- len(messages),
246
- 0,
243
+ INFO,
244
+ "aggregate_fit: received %s results and %s failures",
245
+ len(messages) - num_failures,
246
+ num_failures,
247
247
  )
248
248
 
249
249
  # Aggregate training results
@@ -288,12 +288,11 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
288
288
  client_manager=context.client_manager,
289
289
  )
290
290
  if not client_instructions:
291
- log(INFO, "evaluate_round %s: no clients selected, cancel", current_round)
291
+ log(INFO, "configure_evaluate: no clients selected, skipping evaluation")
292
292
  return
293
293
  log(
294
- DEBUG,
295
- "evaluate_round %s: strategy sampled %s clients (out of %s)",
296
- current_round,
294
+ INFO,
295
+ "configure_evaluate: strategy sampled %s clients (out of %s)",
297
296
  len(client_instructions),
298
297
  context.client_manager.num_available(),
299
298
  )
@@ -307,7 +306,7 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
307
306
  content=compat.evaluateins_to_recordset(evalins, True),
308
307
  message_type=MessageType.EVALUATE,
309
308
  dst_node_id=proxy.node_id,
310
- group_id="",
309
+ group_id=str(current_round),
311
310
  ttl="",
312
311
  )
313
312
  for proxy, evalins in client_instructions
@@ -317,14 +316,14 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
317
316
  # collect `evaluate` results from all clients participating in this round
318
317
  messages = list(driver.send_and_receive(out_messages))
319
318
  del out_messages
319
+ num_failures = len([msg for msg in messages if msg.has_error()])
320
320
 
321
321
  # No exception/failure handling currently
322
322
  log(
323
- DEBUG,
324
- "evaluate_round %s received %s results and %s failures",
325
- current_round,
326
- len(messages),
327
- 0,
323
+ INFO,
324
+ "aggregate_evaluate: received %s results and %s failures",
325
+ len(messages) - num_failures,
326
+ num_failures,
328
327
  )
329
328
 
330
329
  # Aggregate the evaluation results
@@ -15,8 +15,10 @@
15
15
  """Secure Aggregation workflows."""
16
16
 
17
17
 
18
+ from .secagg_workflow import SecAggWorkflow
18
19
  from .secaggplus_workflow import SecAggPlusWorkflow
19
20
 
20
21
  __all__ = [
21
22
  "SecAggPlusWorkflow",
23
+ "SecAggWorkflow",
22
24
  ]
@@ -0,0 +1,112 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Workflow for the SecAgg protocol."""
16
+
17
+
18
+ from typing import Optional, Union
19
+
20
+ from .secaggplus_workflow import SecAggPlusWorkflow
21
+
22
+
23
+ class SecAggWorkflow(SecAggPlusWorkflow):
24
+ """The workflow for the SecAgg protocol.
25
+
26
+ The SecAgg protocol ensures the secure summation of integer vectors owned by
27
+ multiple parties, without accessing any individual integer vector. This workflow
28
+ allows the server to compute the weighted average of model parameters across all
29
+ clients, ensuring individual contributions remain private. This is achieved by
30
+ clients sending both, a weighting factor and a weighted version of the locally
31
+ updated parameters, both of which are masked for privacy. Specifically, each
32
+ client uploads "[w, w * params]" with masks, where weighting factor 'w' is the
33
+ number of examples ('num_examples') and 'params' represents the model parameters
34
+ ('parameters') from the client's `FitRes`. The server then aggregates these
35
+ contributions to compute the weighted average of model parameters.
36
+
37
+ The protocol involves four main stages:
38
+ - 'setup': Send SecAgg configuration to clients and collect their public keys.
39
+ - 'share keys': Broadcast public keys among clients and collect encrypted secret
40
+ key shares.
41
+ - 'collect masked vectors': Forward encrypted secret key shares to target clients
42
+ and collect masked model parameters.
43
+ - 'unmask': Collect secret key shares to decrypt and aggregate the model parameters.
44
+
45
+ Only the aggregated model parameters are exposed and passed to
46
+ `Strategy.aggregate_fit`, ensuring individual data privacy.
47
+
48
+ Parameters
49
+ ----------
50
+ reconstruction_threshold : Union[int, float]
51
+ The minimum number of shares required to reconstruct a client's private key,
52
+ or, if specified as a float, it represents the proportion of the total number
53
+ of shares needed for reconstruction. This threshold ensures privacy by allowing
54
+ for the recovery of contributions from dropped clients during aggregation,
55
+ without compromising individual client data.
56
+ max_weight : Optional[float] (default: 1000.0)
57
+ The maximum value of the weight that can be assigned to any single client's
58
+ update during the weighted average calculation on the server side, e.g., in the
59
+ FedAvg algorithm.
60
+ clipping_range : float, optional (default: 8.0)
61
+ The range within which model parameters are clipped before quantization.
62
+ This parameter ensures each model parameter is bounded within
63
+ [-clipping_range, clipping_range], facilitating quantization.
64
+ quantization_range : int, optional (default: 4194304, this equals 2**22)
65
+ The size of the range into which floating-point model parameters are quantized,
66
+ mapping each parameter to an integer in [0, quantization_range-1]. This
67
+ facilitates cryptographic operations on the model updates.
68
+ modulus_range : int, optional (default: 4294967296, this equals 2**32)
69
+ The range of values from which random mask entries are uniformly sampled
70
+ ([0, modulus_range-1]). `modulus_range` must be less than 4294967296.
71
+ Please use 2**n values for `modulus_range` to prevent overflow issues.
72
+ timeout : Optional[float] (default: None)
73
+ The timeout duration in seconds. If specified, the workflow will wait for
74
+ replies for this duration each time. If `None`, there is no time limit and
75
+ the workflow will wait until replies for all messages are received.
76
+
77
+ Notes
78
+ -----
79
+ - Each client's private key is split into N shares under the SecAgg protocol, where
80
+ N is the number of selected clients.
81
+ - Generally, higher `reconstruction_threshold` means better privacy guarantees but
82
+ less tolerance to dropouts.
83
+ - Too large `max_weight` may compromise the precision of the quantization.
84
+ - `modulus_range` must be 2**n and larger than `quantization_range`.
85
+ - When `reconstruction_threshold` is a float, it is interpreted as the proportion of
86
+ the number of all selected clients needed for the reconstruction of a private key.
87
+ This feature enables flexibility in setting the security threshold relative to the
88
+ number of selected clients.
89
+ - `reconstruction_threshold`, and the quantization parameters
90
+ (`clipping_range`, `quantization_range`, `modulus_range`) play critical roles in
91
+ balancing privacy, robustness, and efficiency within the SecAgg protocol.
92
+ """
93
+
94
+ def __init__( # pylint: disable=R0913
95
+ self,
96
+ reconstruction_threshold: Union[int, float],
97
+ *,
98
+ max_weight: float = 1000.0,
99
+ clipping_range: float = 8.0,
100
+ quantization_range: int = 4194304,
101
+ modulus_range: int = 4294967296,
102
+ timeout: Optional[float] = None,
103
+ ) -> None:
104
+ super().__init__(
105
+ num_shares=1.0,
106
+ reconstruction_threshold=reconstruction_threshold,
107
+ max_weight=max_weight,
108
+ clipping_range=clipping_range,
109
+ quantization_range=quantization_range,
110
+ modulus_range=modulus_range,
111
+ timeout=timeout,
112
+ )
@@ -17,12 +17,11 @@
17
17
 
18
18
  import random
19
19
  from dataclasses import dataclass, field
20
- from logging import ERROR, WARN
21
- from typing import Dict, List, Optional, Set, Union, cast
20
+ from logging import DEBUG, ERROR, INFO, WARN
21
+ from typing import Dict, List, Optional, Set, Tuple, Union, cast
22
22
 
23
23
  import flwr.common.recordset_compat as compat
24
24
  from flwr.common import (
25
- Code,
26
25
  ConfigsRecord,
27
26
  Context,
28
27
  FitRes,
@@ -30,7 +29,6 @@ from flwr.common import (
30
29
  MessageType,
31
30
  NDArrays,
32
31
  RecordSet,
33
- Status,
34
32
  bytes_to_ndarray,
35
33
  log,
36
34
  ndarrays_to_parameters,
@@ -55,7 +53,7 @@ from flwr.common.secure_aggregation.secaggplus_constants import (
55
53
  Stage,
56
54
  )
57
55
  from flwr.common.secure_aggregation.secaggplus_utils import pseudo_rand_gen
58
- from flwr.server.compat.driver_client_proxy import DriverClientProxy
56
+ from flwr.server.client_proxy import ClientProxy
59
57
  from flwr.server.compat.legacy_context import LegacyContext
60
58
  from flwr.server.driver import Driver
61
59
 
@@ -67,6 +65,7 @@ from ..constant import Key as WorkflowKey
67
65
  class WorkflowState: # pylint: disable=R0902
68
66
  """The state of the SecAgg+ protocol."""
69
67
 
68
+ nid_to_proxies: Dict[int, ClientProxy] = field(default_factory=dict)
70
69
  nid_to_fitins: Dict[int, RecordSet] = field(default_factory=dict)
71
70
  sampled_node_ids: Set[int] = field(default_factory=set)
72
71
  active_node_ids: Set[int] = field(default_factory=set)
@@ -81,6 +80,7 @@ class WorkflowState: # pylint: disable=R0902
81
80
  forward_srcs: Dict[int, List[int]] = field(default_factory=dict)
82
81
  forward_ciphertexts: Dict[int, List[bytes]] = field(default_factory=dict)
83
82
  aggregate_ndarrays: NDArrays = field(default_factory=list)
83
+ legacy_results: List[Tuple[ClientProxy, FitRes]] = field(default_factory=list)
84
84
 
85
85
 
86
86
  class SecAggPlusWorkflow:
@@ -101,7 +101,7 @@ class SecAggPlusWorkflow:
101
101
  - 'setup': Send SecAgg+ configuration to clients and collect their public keys.
102
102
  - 'share keys': Broadcast public keys among clients and collect encrypted secret
103
103
  key shares.
104
- - 'collect masked inputs': Forward encrypted secret key shares to target clients
104
+ - 'collect masked vectors': Forward encrypted secret key shares to target clients
105
105
  and collect masked model parameters.
106
106
  - 'unmask': Collect secret key shares to decrypt and aggregate the model parameters.
107
107
 
@@ -195,12 +195,15 @@ class SecAggPlusWorkflow:
195
195
  steps = (
196
196
  self.setup_stage,
197
197
  self.share_keys_stage,
198
- self.collect_masked_input_stage,
198
+ self.collect_masked_vectors_stage,
199
199
  self.unmask_stage,
200
200
  )
201
+ log(INFO, "Secure aggregation commencing.")
201
202
  for step in steps:
202
203
  if not step(driver, context, state):
204
+ log(INFO, "Secure aggregation halted.")
203
205
  return
206
+ log(INFO, "Secure aggregation completed.")
204
207
 
205
208
  def _check_init_params(self) -> None: # pylint: disable=R0912
206
209
  # Check `num_shares`
@@ -287,10 +290,21 @@ class SecAggPlusWorkflow:
287
290
  proxy_fitins_lst = context.strategy.configure_fit(
288
291
  current_round, parameters, context.client_manager
289
292
  )
293
+ if not proxy_fitins_lst:
294
+ log(INFO, "configure_fit: no clients selected, cancel")
295
+ return False
296
+ log(
297
+ INFO,
298
+ "configure_fit: strategy sampled %s clients (out of %s)",
299
+ len(proxy_fitins_lst),
300
+ context.client_manager.num_available(),
301
+ )
302
+
290
303
  state.nid_to_fitins = {
291
- proxy.node_id: compat.fitins_to_recordset(fitins, False)
304
+ proxy.node_id: compat.fitins_to_recordset(fitins, True)
292
305
  for proxy, fitins in proxy_fitins_lst
293
306
  }
307
+ state.nid_to_proxies = {proxy.node_id: proxy for proxy, _ in proxy_fitins_lst}
294
308
 
295
309
  # Protocol config
296
310
  sampled_node_ids = list(state.nid_to_fitins.keys())
@@ -362,12 +376,22 @@ class SecAggPlusWorkflow:
362
376
  ttl="",
363
377
  )
364
378
 
379
+ log(
380
+ DEBUG,
381
+ "[Stage 0] Sending configurations to %s clients.",
382
+ len(state.active_node_ids),
383
+ )
365
384
  msgs = driver.send_and_receive(
366
385
  [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
367
386
  )
368
387
  state.active_node_ids = {
369
388
  msg.metadata.src_node_id for msg in msgs if not msg.has_error()
370
389
  }
390
+ log(
391
+ DEBUG,
392
+ "[Stage 0] Received public keys from %s clients.",
393
+ len(state.active_node_ids),
394
+ )
371
395
 
372
396
  for msg in msgs:
373
397
  if msg.has_error():
@@ -401,12 +425,22 @@ class SecAggPlusWorkflow:
401
425
  )
402
426
 
403
427
  # Broadcast public keys to clients and receive secret key shares
428
+ log(
429
+ DEBUG,
430
+ "[Stage 1] Forwarding public keys to %s clients.",
431
+ len(state.active_node_ids),
432
+ )
404
433
  msgs = driver.send_and_receive(
405
434
  [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
406
435
  )
407
436
  state.active_node_ids = {
408
437
  msg.metadata.src_node_id for msg in msgs if not msg.has_error()
409
438
  }
439
+ log(
440
+ DEBUG,
441
+ "[Stage 1] Received encrypted key shares from %s clients.",
442
+ len(state.active_node_ids),
443
+ )
410
444
 
411
445
  # Build forward packet list dictionary
412
446
  srcs: List[int] = []
@@ -437,16 +471,16 @@ class SecAggPlusWorkflow:
437
471
 
438
472
  return self._check_threshold(state)
439
473
 
440
- def collect_masked_input_stage(
474
+ def collect_masked_vectors_stage(
441
475
  self, driver: Driver, context: LegacyContext, state: WorkflowState
442
476
  ) -> bool:
443
- """Execute the 'collect masked input' stage."""
477
+ """Execute the 'collect masked vectors' stage."""
444
478
  cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
445
479
 
446
- # Send secret key shares to clients (plus FitIns) and collect masked input
480
+ # Send secret key shares to clients (plus FitIns) and collect masked vectors
447
481
  def make(nid: int) -> Message:
448
482
  cfgs_dict = {
449
- Key.STAGE: Stage.COLLECT_MASKED_INPUT,
483
+ Key.STAGE: Stage.COLLECT_MASKED_VECTORS,
450
484
  Key.CIPHERTEXT_LIST: state.forward_ciphertexts[nid],
451
485
  Key.SOURCE_LIST: state.forward_srcs[nid],
452
486
  }
@@ -461,12 +495,22 @@ class SecAggPlusWorkflow:
461
495
  ttl="",
462
496
  )
463
497
 
498
+ log(
499
+ DEBUG,
500
+ "[Stage 2] Forwarding encrypted key shares to %s clients.",
501
+ len(state.active_node_ids),
502
+ )
464
503
  msgs = driver.send_and_receive(
465
504
  [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
466
505
  )
467
506
  state.active_node_ids = {
468
507
  msg.metadata.src_node_id for msg in msgs if not msg.has_error()
469
508
  }
509
+ log(
510
+ DEBUG,
511
+ "[Stage 2] Received masked vectors from %s clients.",
512
+ len(state.active_node_ids),
513
+ )
470
514
 
471
515
  # Clear cache
472
516
  del state.forward_ciphertexts, state.forward_srcs, state.nid_to_fitins
@@ -485,9 +529,15 @@ class SecAggPlusWorkflow:
485
529
  masked_vector = parameters_mod(masked_vector, state.mod_range)
486
530
  state.aggregate_ndarrays = masked_vector
487
531
 
532
+ # Backward compatibility with Strategy
533
+ for msg in msgs:
534
+ fitres = compat.recordset_to_fitres(msg.content, True)
535
+ proxy = state.nid_to_proxies[msg.metadata.src_node_id]
536
+ state.legacy_results.append((proxy, fitres))
537
+
488
538
  return self._check_threshold(state)
489
539
 
490
- def unmask_stage( # pylint: disable=R0912, R0914
540
+ def unmask_stage( # pylint: disable=R0912, R0914, R0915
491
541
  self, driver: Driver, context: LegacyContext, state: WorkflowState
492
542
  ) -> bool:
493
543
  """Execute the 'unmask' stage."""
@@ -516,12 +566,22 @@ class SecAggPlusWorkflow:
516
566
  ttl="",
517
567
  )
518
568
 
569
+ log(
570
+ DEBUG,
571
+ "[Stage 3] Requesting key shares from %s clients to remove masks.",
572
+ len(state.active_node_ids),
573
+ )
519
574
  msgs = driver.send_and_receive(
520
575
  [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
521
576
  )
522
577
  state.active_node_ids = {
523
578
  msg.metadata.src_node_id for msg in msgs if not msg.has_error()
524
579
  }
580
+ log(
581
+ DEBUG,
582
+ "[Stage 3] Received key shares from %s clients.",
583
+ len(state.active_node_ids),
584
+ )
525
585
 
526
586
  # Build collected shares dict
527
587
  collected_shares_dict: Dict[int, List[bytes]] = {}
@@ -534,7 +594,7 @@ class SecAggPlusWorkflow:
534
594
  for owner_nid, share in zip(nids, shares):
535
595
  collected_shares_dict[owner_nid].append(share)
536
596
 
537
- # Remove mask for every client who is available after collect_masked_input stage
597
+ # Remove masks for every active client after collect_masked_vectors stage
538
598
  masked_vector = state.aggregate_ndarrays
539
599
  del state.aggregate_ndarrays
540
600
  for nid, share_list in collected_shares_dict.items():
@@ -584,18 +644,30 @@ class SecAggPlusWorkflow:
584
644
  for vec in aggregated_vector:
585
645
  vec += offset
586
646
  vec *= inv_dq_total_ratio
587
- state.aggregate_ndarrays = aggregated_vector
588
- final_fitres = FitRes(
589
- status=Status(code=Code.OK, message=""),
590
- parameters=ndarrays_to_parameters(aggregated_vector),
591
- num_examples=round(state.max_weight / inv_dq_total_ratio),
592
- metrics={},
593
- )
594
- empty_proxy = DriverClientProxy(
647
+
648
+ # Backward compatibility with Strategy
649
+ results = state.legacy_results
650
+ parameters = ndarrays_to_parameters(aggregated_vector)
651
+ for _, fitres in results:
652
+ fitres.parameters = parameters
653
+
654
+ # No exception/failure handling currently
655
+ log(
656
+ INFO,
657
+ "aggregate_fit: received %s results and %s failures",
658
+ len(results),
595
659
  0,
596
- driver.grpc_driver, # type: ignore
597
- False,
598
- driver.run_id, # type: ignore
599
660
  )
600
- context.strategy.aggregate_fit(current_round, [(empty_proxy, final_fitres)], [])
661
+ aggregated_result = context.strategy.aggregate_fit(current_round, results, [])
662
+ parameters_aggregated, metrics_aggregated = aggregated_result
663
+
664
+ # Update the parameters and write history
665
+ if parameters_aggregated:
666
+ paramsrecord = compat.parameters_to_parametersrecord(
667
+ parameters_aggregated, True
668
+ )
669
+ context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
670
+ context.history.add_metrics_distributed_fit(
671
+ server_round=current_round, metrics=metrics_aggregated
672
+ )
601
673
  return True
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: flwr-nightly
3
- Version: 1.8.0.dev20240310
3
+ Version: 1.8.0.dev20240312
4
4
  Summary: Flower: A Friendly Federated Learning Framework
5
5
  Home-page: https://flower.ai
6
6
  License: Apache-2.0