flwr-nightly 1.8.0.dev20240310__py3-none-any.whl → 1.8.0.dev20240312__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/new/new.py +6 -3
- flwr/cli/utils.py +14 -1
- flwr/client/app.py +25 -2
- flwr/client/mod/__init__.py +2 -1
- flwr/client/mod/secure_aggregation/__init__.py +2 -0
- flwr/client/mod/secure_aggregation/secagg_mod.py +30 -0
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +42 -51
- flwr/common/logger.py +6 -8
- flwr/common/pyproject.py +41 -0
- flwr/common/secure_aggregation/secaggplus_constants.py +2 -2
- flwr/server/superlink/state/in_memory_state.py +34 -32
- flwr/server/workflow/__init__.py +2 -1
- flwr/server/workflow/default_workflows.py +39 -40
- flwr/server/workflow/secure_aggregation/__init__.py +2 -0
- flwr/server/workflow/secure_aggregation/secagg_workflow.py +112 -0
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +98 -26
- {flwr_nightly-1.8.0.dev20240310.dist-info → flwr_nightly-1.8.0.dev20240312.dist-info}/METADATA +1 -1
- {flwr_nightly-1.8.0.dev20240310.dist-info → flwr_nightly-1.8.0.dev20240312.dist-info}/RECORD +21 -18
- {flwr_nightly-1.8.0.dev20240310.dist-info → flwr_nightly-1.8.0.dev20240312.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240310.dist-info → flwr_nightly-1.8.0.dev20240312.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.8.0.dev20240310.dist-info → flwr_nightly-1.8.0.dev20240312.dist-info}/entry_points.txt +0 -0
flwr/server/workflow/default_workflows.py

@@ -15,8 +15,9 @@
 """Legacy default workflows."""
 
 
+import io
 import timeit
-from logging import
+from logging import INFO
 from typing import Optional, cast
 
 import flwr.common.recordset_compat as compat
@@ -58,16 +59,18 @@ class DefaultWorkflow:
         )
 
         # Initialize parameters
+        log(INFO, "[INIT]")
         default_init_params_workflow(driver, context)
 
         # Run federated learning for num_rounds
-        log(INFO, "FL starting")
         start_time = timeit.default_timer()
         cfg = ConfigsRecord()
         cfg[Key.START_TIME] = start_time
         context.state.configs_records[MAIN_CONFIGS_RECORD] = cfg
 
         for current_round in range(1, context.config.num_rounds + 1):
+            log(INFO, "")
+            log(INFO, "[ROUND %s]", current_round)
             cfg[Key.CURRENT_ROUND] = current_round
 
             # Fit round
@@ -79,22 +82,19 @@ class DefaultWorkflow:
             # Evaluate round
             self.evaluate_workflow(driver, context)
 
-        # Bookkeeping
+        # Bookkeeping and log results
         end_time = timeit.default_timer()
         elapsed = end_time - start_time
-        log(INFO, "FL finished in %s", elapsed)
-
-        # Log results
         hist = context.history
-        log(INFO, "
-        log(
-
-
-
-
-
-
-        log(INFO, "
+        log(INFO, "")
+        log(INFO, "[SUMMARY]")
+        log(INFO, "Run finished %s rounds in %.2fs", context.config.num_rounds, elapsed)
+        for idx, line in enumerate(io.StringIO(str(hist))):
+            if idx == 0:
+                log(INFO, "%s", line.strip("\n"))
+            else:
+                log(INFO, "\t%s", line.strip("\n"))
+        log(INFO, "")
 
         # Terminate the thread
         f_stop.set()
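The new [SUMMARY] block logs the multi-line History repr one line at a time so that every line goes through the logger with a consistent prefix. A minimal standalone sketch of that io.StringIO pattern; the history_text value here is made up purely for illustration:

```python
import io
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("summary-demo")

# Stand-in for str(context.history); the content is illustrative only
history_text = "History (loss, distributed):\n\tround 1: 0.62\n\tround 2: 0.48"

# io.StringIO is iterable over lines, so a multi-line string can be logged
# line by line, indenting every line after the first
for idx, line in enumerate(io.StringIO(history_text)):
    if idx == 0:
        logger.info("%s", line.strip("\n"))
    else:
        logger.info("\t%s", line.strip("\n"))
```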
@@ -107,12 +107,11 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
     if not isinstance(context, LegacyContext):
         raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
 
-    log(INFO, "Initializing global parameters")
     parameters = context.strategy.initialize_parameters(
         client_manager=context.client_manager
     )
     if parameters is not None:
-        log(INFO, "Using initial parameters provided by strategy")
+        log(INFO, "Using initial global parameters provided by strategy")
         paramsrecord = compat.parameters_to_parametersrecord(
             parameters, keep_input=True
         )
@@ -128,7 +127,7 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
                     content=content,
                     message_type=MessageTypeLegacy.GET_PARAMETERS,
                     dst_node_id=random_client.node_id,
-                    group_id="",
+                    group_id="0",
                     ttl="",
                 )
             ]
@@ -140,7 +139,7 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
     context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
 
     # Evaluate initial parameters
-    log(INFO, "Evaluating initial parameters")
+    log(INFO, "Evaluating initial global parameters")
     parameters = compat.parametersrecord_to_parameters(paramsrecord, keep_input=True)
     res = context.strategy.evaluate(0, parameters=parameters)
     if res is not None:
@@ -186,7 +185,9 @@ def default_centralized_evaluation_workflow(_: Driver, context: Context) -> None
     )
 
 
-def default_fit_workflow(driver: Driver, context: Context) -> None:
+def default_fit_workflow(  # pylint: disable=R0914
+    driver: Driver, context: Context
+) -> None:
     """Execute the default workflow for a single fit round."""
     if not isinstance(context, LegacyContext):
         raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
@@ -207,12 +208,11 @@ def default_fit_workflow(driver: Driver, context: Context) -> None:
     )
 
     if not client_instructions:
-        log(INFO, "
+        log(INFO, "configure_fit: no clients selected, cancel")
         return
     log(
-
-        "
-        current_round,
+        INFO,
+        "configure_fit: strategy sampled %s clients (out of %s)",
         len(client_instructions),
         context.client_manager.num_available(),
     )
@@ -226,7 +226,7 @@ def default_fit_workflow(driver: Driver, context: Context) -> None:
             content=compat.fitins_to_recordset(fitins, True),
             message_type=MessageType.TRAIN,
             dst_node_id=proxy.node_id,
-            group_id=
+            group_id=str(current_round),
             ttl="",
         )
         for proxy, fitins in client_instructions
@@ -236,14 +236,14 @@ def default_fit_workflow(driver: Driver, context: Context) -> None:
     # collect `fit` results from all clients participating in this round
     messages = list(driver.send_and_receive(out_messages))
     del out_messages
+    num_failures = len([msg for msg in messages if msg.has_error()])
 
     # No exception/failure handling currently
     log(
-
-        "
-
-
-        0,
+        INFO,
+        "aggregate_fit: received %s results and %s failures",
+        len(messages) - num_failures,
+        num_failures,
     )
 
     # Aggregate training results
@@ -288,12 +288,11 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
         client_manager=context.client_manager,
     )
     if not client_instructions:
-        log(INFO, "
+        log(INFO, "configure_evaluate: no clients selected, skipping evaluation")
         return
     log(
-
-        "
-        current_round,
+        INFO,
+        "configure_evaluate: strategy sampled %s clients (out of %s)",
         len(client_instructions),
         context.client_manager.num_available(),
     )
@@ -307,7 +306,7 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
             content=compat.evaluateins_to_recordset(evalins, True),
             message_type=MessageType.EVALUATE,
             dst_node_id=proxy.node_id,
-            group_id=
+            group_id=str(current_round),
             ttl="",
         )
         for proxy, evalins in client_instructions
@@ -317,14 +316,14 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
     # collect `evaluate` results from all clients participating in this round
    messages = list(driver.send_and_receive(out_messages))
     del out_messages
+    num_failures = len([msg for msg in messages if msg.has_error()])
 
     # No exception/failure handling currently
     log(
-
-        "
-
-
-        0,
+        INFO,
+        "aggregate_evaluate: received %s results and %s failures",
+        len(messages) - num_failures,
+        num_failures,
     )
 
     # Aggregate the evaluation results
flwr/server/workflow/secure_aggregation/secagg_workflow.py (new file)

@@ -0,0 +1,112 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Workflow for the SecAgg protocol."""
+
+
+from typing import Optional, Union
+
+from .secaggplus_workflow import SecAggPlusWorkflow
+
+
+class SecAggWorkflow(SecAggPlusWorkflow):
+    """The workflow for the SecAgg protocol.
+
+    The SecAgg protocol ensures the secure summation of integer vectors owned by
+    multiple parties, without accessing any individual integer vector. This workflow
+    allows the server to compute the weighted average of model parameters across all
+    clients, ensuring individual contributions remain private. This is achieved by
+    clients sending both, a weighting factor and a weighted version of the locally
+    updated parameters, both of which are masked for privacy. Specifically, each
+    client uploads "[w, w * params]" with masks, where weighting factor 'w' is the
+    number of examples ('num_examples') and 'params' represents the model parameters
+    ('parameters') from the client's `FitRes`. The server then aggregates these
+    contributions to compute the weighted average of model parameters.
+
+    The protocol involves four main stages:
+    - 'setup': Send SecAgg configuration to clients and collect their public keys.
+    - 'share keys': Broadcast public keys among clients and collect encrypted secret
+      key shares.
+    - 'collect masked vectors': Forward encrypted secret key shares to target clients
+      and collect masked model parameters.
+    - 'unmask': Collect secret key shares to decrypt and aggregate the model parameters.
+
+    Only the aggregated model parameters are exposed and passed to
+    `Strategy.aggregate_fit`, ensuring individual data privacy.
+
+    Parameters
+    ----------
+    reconstruction_threshold : Union[int, float]
+        The minimum number of shares required to reconstruct a client's private key,
+        or, if specified as a float, it represents the proportion of the total number
+        of shares needed for reconstruction. This threshold ensures privacy by allowing
+        for the recovery of contributions from dropped clients during aggregation,
+        without compromising individual client data.
+    max_weight : Optional[float] (default: 1000.0)
+        The maximum value of the weight that can be assigned to any single client's
+        update during the weighted average calculation on the server side, e.g., in the
+        FedAvg algorithm.
+    clipping_range : float, optional (default: 8.0)
+        The range within which model parameters are clipped before quantization.
+        This parameter ensures each model parameter is bounded within
+        [-clipping_range, clipping_range], facilitating quantization.
+    quantization_range : int, optional (default: 4194304, this equals 2**22)
+        The size of the range into which floating-point model parameters are quantized,
+        mapping each parameter to an integer in [0, quantization_range-1]. This
+        facilitates cryptographic operations on the model updates.
+    modulus_range : int, optional (default: 4294967296, this equals 2**32)
+        The range of values from which random mask entries are uniformly sampled
+        ([0, modulus_range-1]). `modulus_range` must be less than 4294967296.
+        Please use 2**n values for `modulus_range` to prevent overflow issues.
+    timeout : Optional[float] (default: None)
+        The timeout duration in seconds. If specified, the workflow will wait for
+        replies for this duration each time. If `None`, there is no time limit and
+        the workflow will wait until replies for all messages are received.
+
+    Notes
+    -----
+    - Each client's private key is split into N shares under the SecAgg protocol, where
+      N is the number of selected clients.
+    - Generally, higher `reconstruction_threshold` means better privacy guarantees but
+      less tolerance to dropouts.
+    - Too large `max_weight` may compromise the precision of the quantization.
+    - `modulus_range` must be 2**n and larger than `quantization_range`.
+    - When `reconstruction_threshold` is a float, it is interpreted as the proportion of
+      the number of all selected clients needed for the reconstruction of a private key.
+      This feature enables flexibility in setting the security threshold relative to the
+      number of selected clients.
+    - `reconstruction_threshold`, and the quantization parameters
+      (`clipping_range`, `quantization_range`, `modulus_range`) play critical roles in
+      balancing privacy, robustness, and efficiency within the SecAgg protocol.
+    """
+
+    def __init__(  # pylint: disable=R0913
+        self,
+        reconstruction_threshold: Union[int, float],
+        *,
+        max_weight: float = 1000.0,
+        clipping_range: float = 8.0,
+        quantization_range: int = 4194304,
+        modulus_range: int = 4294967296,
+        timeout: Optional[float] = None,
+    ) -> None:
+        super().__init__(
+            num_shares=1.0,
+            reconstruction_threshold=reconstruction_threshold,
+            max_weight=max_weight,
+            clipping_range=clipping_range,
+            quantization_range=quantization_range,
+            modulus_range=modulus_range,
+            timeout=timeout,
+        )
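A sketch of how the new class could be plugged into a ServerApp as the fit workflow, following the Flower 1.8 workflow API; the strategy, round count, and threshold below are placeholder choices, and exact import paths should be checked against the installed package:

```python
from flwr.common import Context
from flwr.server import Driver, LegacyContext, ServerApp, ServerConfig
from flwr.server.strategy import FedAvg
from flwr.server.workflow import DefaultWorkflow
from flwr.server.workflow.secure_aggregation import SecAggWorkflow

app = ServerApp()


@app.main()
def main(driver: Driver, context: Context) -> None:
    # Wrap the incoming context so the legacy Strategy/ClientManager APIs keep working
    context = LegacyContext(
        state=context.state,
        config=ServerConfig(num_rounds=3),
        strategy=FedAvg(),
    )

    # Use the SecAgg workflow for the fit phase of every round
    workflow = DefaultWorkflow(
        fit_workflow=SecAggWorkflow(reconstruction_threshold=0.5),
    )

    # Run the workflow
    workflow(driver, context)
```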
flwr/server/workflow/secure_aggregation/secaggplus_workflow.py

@@ -17,12 +17,11 @@
 
 import random
 from dataclasses import dataclass, field
-from logging import ERROR, WARN
-from typing import Dict, List, Optional, Set, Union, cast
+from logging import DEBUG, ERROR, INFO, WARN
+from typing import Dict, List, Optional, Set, Tuple, Union, cast
 
 import flwr.common.recordset_compat as compat
 from flwr.common import (
-    Code,
     ConfigsRecord,
     Context,
     FitRes,
@@ -30,7 +29,6 @@ from flwr.common import (
     MessageType,
     NDArrays,
     RecordSet,
-    Status,
     bytes_to_ndarray,
     log,
     ndarrays_to_parameters,
@@ -55,7 +53,7 @@ from flwr.common.secure_aggregation.secaggplus_constants import (
     Stage,
 )
 from flwr.common.secure_aggregation.secaggplus_utils import pseudo_rand_gen
-from flwr.server.
+from flwr.server.client_proxy import ClientProxy
 from flwr.server.compat.legacy_context import LegacyContext
 from flwr.server.driver import Driver
 
@@ -67,6 +65,7 @@ from ..constant import Key as WorkflowKey
 class WorkflowState:  # pylint: disable=R0902
     """The state of the SecAgg+ protocol."""
 
+    nid_to_proxies: Dict[int, ClientProxy] = field(default_factory=dict)
     nid_to_fitins: Dict[int, RecordSet] = field(default_factory=dict)
     sampled_node_ids: Set[int] = field(default_factory=set)
     active_node_ids: Set[int] = field(default_factory=set)
@@ -81,6 +80,7 @@ class WorkflowState:  # pylint: disable=R0902
     forward_srcs: Dict[int, List[int]] = field(default_factory=dict)
     forward_ciphertexts: Dict[int, List[bytes]] = field(default_factory=dict)
     aggregate_ndarrays: NDArrays = field(default_factory=list)
+    legacy_results: List[Tuple[ClientProxy, FitRes]] = field(default_factory=list)
 
 
 class SecAggPlusWorkflow:
@@ -101,7 +101,7 @@ class SecAggPlusWorkflow:
     - 'setup': Send SecAgg+ configuration to clients and collect their public keys.
     - 'share keys': Broadcast public keys among clients and collect encrypted secret
       key shares.
-    - 'collect masked
+    - 'collect masked vectors': Forward encrypted secret key shares to target clients
       and collect masked model parameters.
     - 'unmask': Collect secret key shares to decrypt and aggregate the model parameters.
 
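The stages listed above implement pairwise masking: every pair of clients derives a shared mask that one adds and the other subtracts, so the masks cancel in the server-side sum while individual uploads stay hidden. A toy numpy illustration of that cancellation; this is not the flwr implementation, which operates on quantized integers using key agreement and secret sharing of the mask seeds:

```python
import numpy as np

rng = np.random.default_rng(seed=42)

# Three clients with private update vectors
updates = [rng.normal(size=4) for _ in range(3)]

# Pairwise masks: for each pair (i, j), client i adds m_ij and client j subtracts it
masks = {(i, j): rng.normal(size=4) for i in range(3) for j in range(i + 1, 3)}

masked = []
for i, update in enumerate(updates):
    vec = update.copy()
    for (a, b), mask in masks.items():
        if i == a:
            vec += mask  # the lower-id client adds the shared mask
        elif i == b:
            vec -= mask  # the higher-id client subtracts the same mask
    masked.append(vec)

# Each masked vector alone looks random, but all masks cancel in the sum,
# so summing the masked vectors recovers the sum of the plain updates.
assert np.allclose(sum(masked), sum(updates))
```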
@@ -195,12 +195,15 @@ class SecAggPlusWorkflow:
         steps = (
             self.setup_stage,
             self.share_keys_stage,
-            self.
+            self.collect_masked_vectors_stage,
             self.unmask_stage,
         )
+        log(INFO, "Secure aggregation commencing.")
         for step in steps:
             if not step(driver, context, state):
+                log(INFO, "Secure aggregation halted.")
                 return
+        log(INFO, "Secure aggregation completed.")
 
     def _check_init_params(self) -> None:  # pylint: disable=R0912
         # Check `num_shares`
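Each stage in the loop above returns a bool, and the workflow stops as soon as one stage reports failure (for example, when too few clients remain active). A minimal stand-alone version of that control-flow pattern; the names here are illustrative, not flwr APIs:

```python
from typing import Callable, Sequence


def run_protocol(stages: Sequence[Callable[[], bool]]) -> bool:
    """Run stages in order; halt early when a stage returns False."""
    print("Secure aggregation commencing.")
    for stage in stages:
        if not stage():
            print("Secure aggregation halted.")
            return False
    print("Secure aggregation completed.")
    return True


# The third stage fails (e.g., too few clients replied), so the last one never runs
run_protocol([lambda: True, lambda: True, lambda: False, lambda: True])
```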
@@ -287,10 +290,21 @@ class SecAggPlusWorkflow:
         proxy_fitins_lst = context.strategy.configure_fit(
             current_round, parameters, context.client_manager
         )
+        if not proxy_fitins_lst:
+            log(INFO, "configure_fit: no clients selected, cancel")
+            return False
+        log(
+            INFO,
+            "configure_fit: strategy sampled %s clients (out of %s)",
+            len(proxy_fitins_lst),
+            context.client_manager.num_available(),
+        )
+
         state.nid_to_fitins = {
-            proxy.node_id: compat.fitins_to_recordset(fitins,
+            proxy.node_id: compat.fitins_to_recordset(fitins, True)
             for proxy, fitins in proxy_fitins_lst
         }
+        state.nid_to_proxies = {proxy.node_id: proxy for proxy, _ in proxy_fitins_lst}
 
         # Protocol config
         sampled_node_ids = list(state.nid_to_fitins.keys())
@@ -362,12 +376,22 @@ class SecAggPlusWorkflow:
                 ttl="",
             )
 
+        log(
+            DEBUG,
+            "[Stage 0] Sending configurations to %s clients.",
+            len(state.active_node_ids),
+        )
         msgs = driver.send_and_receive(
             [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
         )
         state.active_node_ids = {
             msg.metadata.src_node_id for msg in msgs if not msg.has_error()
         }
+        log(
+            DEBUG,
+            "[Stage 0] Received public keys from %s clients.",
+            len(state.active_node_ids),
+        )
 
         for msg in msgs:
             if msg.has_error():
@@ -401,12 +425,22 @@ class SecAggPlusWorkflow:
             )
 
         # Broadcast public keys to clients and receive secret key shares
+        log(
+            DEBUG,
+            "[Stage 1] Forwarding public keys to %s clients.",
+            len(state.active_node_ids),
+        )
         msgs = driver.send_and_receive(
             [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
         )
         state.active_node_ids = {
             msg.metadata.src_node_id for msg in msgs if not msg.has_error()
         }
+        log(
+            DEBUG,
+            "[Stage 1] Received encrypted key shares from %s clients.",
+            len(state.active_node_ids),
+        )
 
         # Build forward packet list dictionary
         srcs: List[int] = []
@@ -437,16 +471,16 @@ class SecAggPlusWorkflow:
 
         return self._check_threshold(state)
 
-    def
+    def collect_masked_vectors_stage(
         self, driver: Driver, context: LegacyContext, state: WorkflowState
     ) -> bool:
-        """Execute the 'collect masked
+        """Execute the 'collect masked vectors' stage."""
         cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
 
-        # Send secret key shares to clients (plus FitIns) and collect masked
+        # Send secret key shares to clients (plus FitIns) and collect masked vectors
         def make(nid: int) -> Message:
             cfgs_dict = {
-                Key.STAGE: Stage.
+                Key.STAGE: Stage.COLLECT_MASKED_VECTORS,
                 Key.CIPHERTEXT_LIST: state.forward_ciphertexts[nid],
                 Key.SOURCE_LIST: state.forward_srcs[nid],
             }
@@ -461,12 +495,22 @@ class SecAggPlusWorkflow:
                 ttl="",
             )
 
+        log(
+            DEBUG,
+            "[Stage 2] Forwarding encrypted key shares to %s clients.",
+            len(state.active_node_ids),
+        )
         msgs = driver.send_and_receive(
             [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
         )
         state.active_node_ids = {
             msg.metadata.src_node_id for msg in msgs if not msg.has_error()
         }
+        log(
+            DEBUG,
+            "[Stage 2] Received masked vectors from %s clients.",
+            len(state.active_node_ids),
+        )
 
         # Clear cache
         del state.forward_ciphertexts, state.forward_srcs, state.nid_to_fitins
@@ -485,9 +529,15 @@ class SecAggPlusWorkflow:
         masked_vector = parameters_mod(masked_vector, state.mod_range)
         state.aggregate_ndarrays = masked_vector
 
+        # Backward compatibility with Strategy
+        for msg in msgs:
+            fitres = compat.recordset_to_fitres(msg.content, True)
+            proxy = state.nid_to_proxies[msg.metadata.src_node_id]
+            state.legacy_results.append((proxy, fitres))
+
         return self._check_threshold(state)
 
-    def unmask_stage(  # pylint: disable=R0912, R0914
+    def unmask_stage(  # pylint: disable=R0912, R0914, R0915
         self, driver: Driver, context: LegacyContext, state: WorkflowState
     ) -> bool:
         """Execute the 'unmask' stage."""
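The "backward compatibility with Strategy" shim collects a (ClientProxy, FitRes) pair per client here and, in the unmask stage below, overwrites every FitRes with the already-aggregated parameters before handing the list to Strategy.aggregate_fit. One way to see why that keeps an unmodified averaging strategy working is that a weighted average of identical vectors returns that vector. A toy check of that property; this is illustrative reasoning, not flwr code:

```python
import numpy as np


def weighted_average(updates: list[tuple[int, np.ndarray]]) -> np.ndarray:
    """FedAvg-style weighted mean over (num_examples, parameters) pairs."""
    total = sum(n for n, _ in updates)
    return sum(n * vec for n, vec in updates) / total


aggregated = np.array([0.1, -0.4, 0.7])  # stands in for the securely aggregated vector
results = [(12, aggregated), (30, aggregated), (5, aggregated)]

# Averaging identical vectors yields the same vector, regardless of the weights
assert np.allclose(weighted_average(results), aggregated)
```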
@@ -516,12 +566,22 @@ class SecAggPlusWorkflow:
                 ttl="",
             )
 
+        log(
+            DEBUG,
+            "[Stage 3] Requesting key shares from %s clients to remove masks.",
+            len(state.active_node_ids),
+        )
         msgs = driver.send_and_receive(
             [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
         )
         state.active_node_ids = {
             msg.metadata.src_node_id for msg in msgs if not msg.has_error()
         }
+        log(
+            DEBUG,
+            "[Stage 3] Received key shares from %s clients.",
+            len(state.active_node_ids),
+        )
 
         # Build collected shares dict
         collected_shares_dict: Dict[int, List[bytes]] = {}
@@ -534,7 +594,7 @@ class SecAggPlusWorkflow:
         for owner_nid, share in zip(nids, shares):
             collected_shares_dict[owner_nid].append(share)
 
-        # Remove
+        # Remove masks for every active client after collect_masked_vectors stage
         masked_vector = state.aggregate_ndarrays
         del state.aggregate_ndarrays
         for nid, share_list in collected_shares_dict.items():
@@ -584,18 +644,30 @@ class SecAggPlusWorkflow:
         for vec in aggregated_vector:
             vec += offset
             vec *= inv_dq_total_ratio
-
-
-
-
-
-
-
-
+
+        # Backward compatibility with Strategy
+        results = state.legacy_results
+        parameters = ndarrays_to_parameters(aggregated_vector)
+        for _, fitres in results:
+            fitres.parameters = parameters
+
+        # No exception/failure handling currently
+        log(
+            INFO,
+            "aggregate_fit: received %s results and %s failures",
+            len(results),
             0,
-            driver.grpc_driver,  # type: ignore
-            False,
-            driver.run_id,  # type: ignore
         )
-        context.strategy.aggregate_fit(current_round,
+        aggregated_result = context.strategy.aggregate_fit(current_round, results, [])
+        parameters_aggregated, metrics_aggregated = aggregated_result
+
+        # Update the parameters and write history
+        if parameters_aggregated:
+            paramsrecord = compat.parameters_to_parametersrecord(
+                parameters_aggregated, True
+            )
+            context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
+            context.history.add_metrics_distributed_fit(
+                server_round=current_round, metrics=metrics_aggregated
+            )
         return True