flwr-nightly 1.8.0.dev20240310__py3-none-any.whl → 1.8.0.dev20240312__py3-none-any.whl
- flwr/cli/new/new.py +6 -3
- flwr/cli/utils.py +14 -1
- flwr/client/app.py +25 -2
- flwr/client/mod/__init__.py +2 -1
- flwr/client/mod/secure_aggregation/__init__.py +2 -0
- flwr/client/mod/secure_aggregation/secagg_mod.py +30 -0
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +42 -51
- flwr/common/logger.py +6 -8
- flwr/common/pyproject.py +41 -0
- flwr/common/secure_aggregation/secaggplus_constants.py +2 -2
- flwr/server/superlink/state/in_memory_state.py +34 -32
- flwr/server/workflow/__init__.py +2 -1
- flwr/server/workflow/default_workflows.py +39 -40
- flwr/server/workflow/secure_aggregation/__init__.py +2 -0
- flwr/server/workflow/secure_aggregation/secagg_workflow.py +112 -0
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +98 -26
- {flwr_nightly-1.8.0.dev20240310.dist-info → flwr_nightly-1.8.0.dev20240312.dist-info}/METADATA +1 -1
- {flwr_nightly-1.8.0.dev20240310.dist-info → flwr_nightly-1.8.0.dev20240312.dist-info}/RECORD +21 -18
- {flwr_nightly-1.8.0.dev20240310.dist-info → flwr_nightly-1.8.0.dev20240312.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240310.dist-info → flwr_nightly-1.8.0.dev20240312.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.8.0.dev20240310.dist-info → flwr_nightly-1.8.0.dev20240312.dist-info}/entry_points.txt +0 -0
flwr/server/workflow/default_workflows.py

@@ -15,8 +15,9 @@
 """Legacy default workflows."""
 
 
+import io
 import timeit
-from logging import
+from logging import INFO
 from typing import Optional, cast
 
 import flwr.common.recordset_compat as compat

@@ -58,16 +59,18 @@ class DefaultWorkflow:
         )
 
         # Initialize parameters
+        log(INFO, "[INIT]")
         default_init_params_workflow(driver, context)
 
         # Run federated learning for num_rounds
-        log(INFO, "FL starting")
         start_time = timeit.default_timer()
         cfg = ConfigsRecord()
         cfg[Key.START_TIME] = start_time
         context.state.configs_records[MAIN_CONFIGS_RECORD] = cfg
 
         for current_round in range(1, context.config.num_rounds + 1):
+            log(INFO, "")
+            log(INFO, "[ROUND %s]", current_round)
             cfg[Key.CURRENT_ROUND] = current_round
 
             # Fit round

@@ -79,22 +82,19 @@ class DefaultWorkflow:
             # Evaluate round
             self.evaluate_workflow(driver, context)
 
-        # Bookkeeping
+        # Bookkeeping and log results
         end_time = timeit.default_timer()
         elapsed = end_time - start_time
-        log(INFO, "FL finished in %s", elapsed)
-
-        # Log results
         hist = context.history
-        log(INFO, "
-        log(
-
-
-
-
-
-
-        log(INFO, "
+        log(INFO, "")
+        log(INFO, "[SUMMARY]")
+        log(INFO, "Run finished %s rounds in %.2fs", context.config.num_rounds, elapsed)
+        for idx, line in enumerate(io.StringIO(str(hist))):
+            if idx == 0:
+                log(INFO, "%s", line.strip("\n"))
+            else:
+                log(INFO, "\t%s", line.strip("\n"))
+        log(INFO, "")
 
         # Terminate the thread
         f_stop.set()

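The new [SUMMARY] block replaces the old one-line "FL finished in %s" message: it emits the multi-line `History` string one log record per line. A minimal standalone sketch of the same pattern, with the standard `logging` module standing in for Flower's `log` and a stand-in string for `str(context.history)`:

    import io
    import logging
    from logging import INFO

    logging.basicConfig(level=INFO, format="%(message)s")
    log = logging.getLogger("flwr").log

    # Stand-in for str(context.history)
    hist = "History (loss, distributed):\n\tround 1: 0.53\n\tround 2: 0.41"

    log(INFO, "")
    log(INFO, "[SUMMARY]")
    for idx, line in enumerate(io.StringIO(hist)):
        # io.StringIO iterates a string line by line; the first line stays
        # flush-left, later lines get a tab so the block lines up.
        log(INFO, "%s" if idx == 0 else "\t%s", line.strip("\n"))
    log(INFO, "")
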
@@ -107,12 +107,11 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
     if not isinstance(context, LegacyContext):
         raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
 
-    log(INFO, "Initializing global parameters")
     parameters = context.strategy.initialize_parameters(
         client_manager=context.client_manager
     )
     if parameters is not None:
-        log(INFO, "Using initial parameters provided by strategy")
+        log(INFO, "Using initial global parameters provided by strategy")
         paramsrecord = compat.parameters_to_parametersrecord(
             parameters, keep_input=True
         )

@@ -128,7 +127,7 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
                 content=content,
                 message_type=MessageTypeLegacy.GET_PARAMETERS,
                 dst_node_id=random_client.node_id,
-                group_id="",
+                group_id="0",
                 ttl="",
             )
         ]

@@ -140,7 +139,7 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
     context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
 
     # Evaluate initial parameters
-    log(INFO, "Evaluating initial parameters")
+    log(INFO, "Evaluating initial global parameters")
     parameters = compat.parametersrecord_to_parameters(paramsrecord, keep_input=True)
     res = context.strategy.evaluate(0, parameters=parameters)
     if res is not None:

@@ -186,7 +185,9 @@ def default_centralized_evaluation_workflow(_: Driver, context: Context) -> None
     )
 
 
-def default_fit_workflow(driver: Driver, context: Context) -> None:
+def default_fit_workflow(  # pylint: disable=R0914
+    driver: Driver, context: Context
+) -> None:
     """Execute the default workflow for a single fit round."""
     if not isinstance(context, LegacyContext):
         raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")

@@ -207,12 +208,11 @@ def default_fit_workflow(driver: Driver, context: Context) -> None:
     )
 
     if not client_instructions:
-        log(INFO, "
+        log(INFO, "configure_fit: no clients selected, cancel")
         return
     log(
-
-        "
-        current_round,
+        INFO,
+        "configure_fit: strategy sampled %s clients (out of %s)",
         len(client_instructions),
         context.client_manager.num_available(),
     )

@@ -226,7 +226,7 @@ def default_fit_workflow(driver: Driver, context: Context) -> None:
             content=compat.fitins_to_recordset(fitins, True),
             message_type=MessageType.TRAIN,
             dst_node_id=proxy.node_id,
-            group_id=
+            group_id=str(current_round),
             ttl="",
         )
         for proxy, fitins in client_instructions

@@ -236,14 +236,14 @@ def default_fit_workflow(driver: Driver, context: Context) -> None:
     # collect `fit` results from all clients participating in this round
     messages = list(driver.send_and_receive(out_messages))
     del out_messages
+    num_failures = len([msg for msg in messages if msg.has_error()])
 
     # No exception/failure handling currently
     log(
-
-        "
-
-
-        0,
+        INFO,
+        "aggregate_fit: received %s results and %s failures",
+        len(messages) - num_failures,
+        num_failures,
     )
 
     # Aggregate training results

@@ -288,12 +288,11 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
         client_manager=context.client_manager,
     )
     if not client_instructions:
-        log(INFO, "
+        log(INFO, "configure_evaluate: no clients selected, skipping evaluation")
        return
     log(
-
-        "
-        current_round,
+        INFO,
+        "configure_evaluate: strategy sampled %s clients (out of %s)",
         len(client_instructions),
         context.client_manager.num_available(),
     )

@@ -307,7 +306,7 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
             content=compat.evaluateins_to_recordset(evalins, True),
             message_type=MessageType.EVALUATE,
             dst_node_id=proxy.node_id,
-            group_id=
+            group_id=str(current_round),
             ttl="",
         )
         for proxy, evalins in client_instructions

@@ -317,14 +316,14 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
     # collect `evaluate` results from all clients participating in this round
     messages = list(driver.send_and_receive(out_messages))
     del out_messages
+    num_failures = len([msg for msg in messages if msg.has_error()])
 
     # No exception/failure handling currently
     log(
-
-        "
-
-
-        0,
+        INFO,
+        "aggregate_evaluate: received %s results and %s failures",
+        len(messages) - num_failures,
+        num_failures,
     )
 
     # Aggregate the evaluation results

flwr/server/workflow/secure_aggregation/secagg_workflow.py (new file)

@@ -0,0 +1,112 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Workflow for the SecAgg protocol."""
+
+
+from typing import Optional, Union
+
+from .secaggplus_workflow import SecAggPlusWorkflow
+
+
+class SecAggWorkflow(SecAggPlusWorkflow):
+    """The workflow for the SecAgg protocol.
+
+    The SecAgg protocol ensures the secure summation of integer vectors owned by
+    multiple parties, without accessing any individual integer vector. This workflow
+    allows the server to compute the weighted average of model parameters across all
+    clients, ensuring individual contributions remain private. This is achieved by
+    clients sending both, a weighting factor and a weighted version of the locally
+    updated parameters, both of which are masked for privacy. Specifically, each
+    client uploads "[w, w * params]" with masks, where weighting factor 'w' is the
+    number of examples ('num_examples') and 'params' represents the model parameters
+    ('parameters') from the client's `FitRes`. The server then aggregates these
+    contributions to compute the weighted average of model parameters.
+
+    The protocol involves four main stages:
+    - 'setup': Send SecAgg configuration to clients and collect their public keys.
+    - 'share keys': Broadcast public keys among clients and collect encrypted secret
+      key shares.
+    - 'collect masked vectors': Forward encrypted secret key shares to target clients
+      and collect masked model parameters.
+    - 'unmask': Collect secret key shares to decrypt and aggregate the model parameters.
+
+    Only the aggregated model parameters are exposed and passed to
+    `Strategy.aggregate_fit`, ensuring individual data privacy.
+
+    Parameters
+    ----------
+    reconstruction_threshold : Union[int, float]
+        The minimum number of shares required to reconstruct a client's private key,
+        or, if specified as a float, it represents the proportion of the total number
+        of shares needed for reconstruction. This threshold ensures privacy by allowing
+        for the recovery of contributions from dropped clients during aggregation,
+        without compromising individual client data.
+    max_weight : Optional[float] (default: 1000.0)
+        The maximum value of the weight that can be assigned to any single client's
+        update during the weighted average calculation on the server side, e.g., in the
+        FedAvg algorithm.
+    clipping_range : float, optional (default: 8.0)
+        The range within which model parameters are clipped before quantization.
+        This parameter ensures each model parameter is bounded within
+        [-clipping_range, clipping_range], facilitating quantization.
+    quantization_range : int, optional (default: 4194304, this equals 2**22)
+        The size of the range into which floating-point model parameters are quantized,
+        mapping each parameter to an integer in [0, quantization_range-1]. This
+        facilitates cryptographic operations on the model updates.
+    modulus_range : int, optional (default: 4294967296, this equals 2**32)
+        The range of values from which random mask entries are uniformly sampled
+        ([0, modulus_range-1]). `modulus_range` must be less than 4294967296.
+        Please use 2**n values for `modulus_range` to prevent overflow issues.
+    timeout : Optional[float] (default: None)
+        The timeout duration in seconds. If specified, the workflow will wait for
+        replies for this duration each time. If `None`, there is no time limit and
+        the workflow will wait until replies for all messages are received.
+
+    Notes
+    -----
+    - Each client's private key is split into N shares under the SecAgg protocol, where
+      N is the number of selected clients.
+    - Generally, higher `reconstruction_threshold` means better privacy guarantees but
+      less tolerance to dropouts.
+    - Too large `max_weight` may compromise the precision of the quantization.
+    - `modulus_range` must be 2**n and larger than `quantization_range`.
+    - When `reconstruction_threshold` is a float, it is interpreted as the proportion of
+      the number of all selected clients needed for the reconstruction of a private key.
+      This feature enables flexibility in setting the security threshold relative to the
+      number of selected clients.
+    - `reconstruction_threshold`, and the quantization parameters
+      (`clipping_range`, `quantization_range`, `modulus_range`) play critical roles in
+      balancing privacy, robustness, and efficiency within the SecAgg protocol.
+    """
+
+    def __init__(  # pylint: disable=R0913
+        self,
+        reconstruction_threshold: Union[int, float],
+        *,
+        max_weight: float = 1000.0,
+        clipping_range: float = 8.0,
+        quantization_range: int = 4194304,
+        modulus_range: int = 4294967296,
+        timeout: Optional[float] = None,
+    ) -> None:
+        super().__init__(
+            num_shares=1.0,
+            reconstruction_threshold=reconstruction_threshold,
+            max_weight=max_weight,
+            clipping_range=clipping_range,
+            quantization_range=quantization_range,
+            modulus_range=modulus_range,
+            timeout=timeout,
+        )

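`SecAggWorkflow` simply pins `num_shares=1.0` (each client's key is split across all selected clients, matching the original SecAgg protocol) and forwards everything else to `SecAggPlusWorkflow`, so wiring it into a server app should mirror SecAgg+ usage. The sketch below is an assumption based on this diff's package layout; in particular, the `fit_workflow=` parameter of `DefaultWorkflow` and the re-export of `SecAggWorkflow` from `flwr.server.workflow` are not shown in this diff:

    from flwr.common import Context
    from flwr.server import Driver, ServerApp, ServerConfig
    from flwr.server.compat.legacy_context import LegacyContext
    from flwr.server.strategy import FedAvg
    from flwr.server.workflow import DefaultWorkflow, SecAggWorkflow

    app = ServerApp()

    @app.main()
    def main(driver: Driver, context: Context) -> None:
        # The default workflows in this diff expect a LegacyContext
        context = LegacyContext(
            state=context.state,
            config=ServerConfig(num_rounds=3),
            strategy=FedAvg(),
        )
        # Run each fit round under SecAgg; with a float threshold, 60% of
        # the sampled clients must contribute key shares during 'unmask'
        workflow = DefaultWorkflow(
            fit_workflow=SecAggWorkflow(reconstruction_threshold=0.6)
        )
        workflow(driver, context)
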
flwr/server/workflow/secure_aggregation/secaggplus_workflow.py

@@ -17,12 +17,11 @@
 
 import random
 from dataclasses import dataclass, field
-from logging import ERROR, WARN
-from typing import Dict, List, Optional, Set, Union, cast
+from logging import DEBUG, ERROR, INFO, WARN
+from typing import Dict, List, Optional, Set, Tuple, Union, cast
 
 import flwr.common.recordset_compat as compat
 from flwr.common import (
-    Code,
     ConfigsRecord,
     Context,
     FitRes,

@@ -30,7 +29,6 @@ from flwr.common import (
     MessageType,
     NDArrays,
     RecordSet,
-    Status,
     bytes_to_ndarray,
     log,
     ndarrays_to_parameters,

@@ -55,7 +53,7 @@ from flwr.common.secure_aggregation.secaggplus_constants import (
     Stage,
 )
 from flwr.common.secure_aggregation.secaggplus_utils import pseudo_rand_gen
-from flwr.server.
+from flwr.server.client_proxy import ClientProxy
 from flwr.server.compat.legacy_context import LegacyContext
 from flwr.server.driver import Driver
 

@@ -67,6 +65,7 @@ from ..constant import Key as WorkflowKey
 class WorkflowState:  # pylint: disable=R0902
     """The state of the SecAgg+ protocol."""
 
+    nid_to_proxies: Dict[int, ClientProxy] = field(default_factory=dict)
     nid_to_fitins: Dict[int, RecordSet] = field(default_factory=dict)
     sampled_node_ids: Set[int] = field(default_factory=set)
     active_node_ids: Set[int] = field(default_factory=set)

@@ -81,6 +80,7 @@ class WorkflowState:  # pylint: disable=R0902
     forward_srcs: Dict[int, List[int]] = field(default_factory=dict)
     forward_ciphertexts: Dict[int, List[bytes]] = field(default_factory=dict)
     aggregate_ndarrays: NDArrays = field(default_factory=list)
+    legacy_results: List[Tuple[ClientProxy, FitRes]] = field(default_factory=list)
 
 
 class SecAggPlusWorkflow:

@@ -101,7 +101,7 @@ class SecAggPlusWorkflow:
     - 'setup': Send SecAgg+ configuration to clients and collect their public keys.
     - 'share keys': Broadcast public keys among clients and collect encrypted secret
       key shares.
-    - 'collect masked
+    - 'collect masked vectors': Forward encrypted secret key shares to target clients
       and collect masked model parameters.
     - 'unmask': Collect secret key shares to decrypt and aggregate the model parameters.
 

@@ -195,12 +195,15 @@ class SecAggPlusWorkflow:
         steps = (
             self.setup_stage,
             self.share_keys_stage,
-            self.
+            self.collect_masked_vectors_stage,
             self.unmask_stage,
         )
+        log(INFO, "Secure aggregation commencing.")
         for step in steps:
             if not step(driver, context, state):
+                log(INFO, "Secure aggregation halted.")
                 return
+        log(INFO, "Secure aggregation completed.")
 
     def _check_init_params(self) -> None:  # pylint: disable=R0912
         # Check `num_shares`

@@ -287,10 +290,21 @@ class SecAggPlusWorkflow:
         proxy_fitins_lst = context.strategy.configure_fit(
             current_round, parameters, context.client_manager
         )
+        if not proxy_fitins_lst:
+            log(INFO, "configure_fit: no clients selected, cancel")
+            return False
+        log(
+            INFO,
+            "configure_fit: strategy sampled %s clients (out of %s)",
+            len(proxy_fitins_lst),
+            context.client_manager.num_available(),
+        )
+
         state.nid_to_fitins = {
-            proxy.node_id: compat.fitins_to_recordset(fitins,
+            proxy.node_id: compat.fitins_to_recordset(fitins, True)
             for proxy, fitins in proxy_fitins_lst
         }
+        state.nid_to_proxies = {proxy.node_id: proxy for proxy, _ in proxy_fitins_lst}
 
         # Protocol config
         sampled_node_ids = list(state.nid_to_fitins.keys())

@@ -362,12 +376,22 @@ class SecAggPlusWorkflow:
                 ttl="",
             )
 
+        log(
+            DEBUG,
+            "[Stage 0] Sending configurations to %s clients.",
+            len(state.active_node_ids),
+        )
         msgs = driver.send_and_receive(
             [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
         )
         state.active_node_ids = {
             msg.metadata.src_node_id for msg in msgs if not msg.has_error()
         }
+        log(
+            DEBUG,
+            "[Stage 0] Received public keys from %s clients.",
+            len(state.active_node_ids),
+        )
 
         for msg in msgs:
             if msg.has_error():

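The new DEBUG lines bracket a pattern every stage repeats: message the currently active nodes, then shrink `active_node_ids` to the nodes whose replies came back without error, so the next stage only addresses survivors. A self-contained sketch of that dropout-handling pattern with stubbed types (`_Reply` and `_send_and_receive` are illustrative stand-ins, not Flower APIs):

    from dataclasses import dataclass
    from typing import List, Set

    @dataclass
    class _Metadata:
        src_node_id: int

    @dataclass
    class _Reply:
        metadata: _Metadata
        error: bool = False

        def has_error(self) -> bool:
            return self.error

    def _send_and_receive(node_ids: Set[int]) -> List[_Reply]:
        # Stub transport: node 3 "drops out", i.e. its reply carries an error
        return [_Reply(_Metadata(nid), error=(nid == 3)) for nid in node_ids]

    active_node_ids = {1, 2, 3}
    msgs = _send_and_receive(active_node_ids)
    active_node_ids = {m.metadata.src_node_id for m in msgs if not m.has_error()}
    assert active_node_ids == {1, 2}  # the next stage targets only survivors
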
@@ -401,12 +425,22 @@ class SecAggPlusWorkflow:
         )
 
         # Broadcast public keys to clients and receive secret key shares
+        log(
+            DEBUG,
+            "[Stage 1] Forwarding public keys to %s clients.",
+            len(state.active_node_ids),
+        )
         msgs = driver.send_and_receive(
             [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
         )
         state.active_node_ids = {
             msg.metadata.src_node_id for msg in msgs if not msg.has_error()
         }
+        log(
+            DEBUG,
+            "[Stage 1] Received encrypted key shares from %s clients.",
+            len(state.active_node_ids),
+        )
 
         # Build forward packet list dictionary
         srcs: List[int] = []

@@ -437,16 +471,16 @@ class SecAggPlusWorkflow:
 
         return self._check_threshold(state)
 
-    def
+    def collect_masked_vectors_stage(
         self, driver: Driver, context: LegacyContext, state: WorkflowState
     ) -> bool:
-        """Execute the 'collect masked
+        """Execute the 'collect masked vectors' stage."""
         cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
 
-        # Send secret key shares to clients (plus FitIns) and collect masked
+        # Send secret key shares to clients (plus FitIns) and collect masked vectors
         def make(nid: int) -> Message:
             cfgs_dict = {
-                Key.STAGE: Stage.
+                Key.STAGE: Stage.COLLECT_MASKED_VECTORS,
                 Key.CIPHERTEXT_LIST: state.forward_ciphertexts[nid],
                 Key.SOURCE_LIST: state.forward_srcs[nid],
             }

@@ -461,12 +495,22 @@ class SecAggPlusWorkflow:
                 ttl="",
             )
 
+        log(
+            DEBUG,
+            "[Stage 2] Forwarding encrypted key shares to %s clients.",
+            len(state.active_node_ids),
+        )
         msgs = driver.send_and_receive(
             [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
         )
         state.active_node_ids = {
             msg.metadata.src_node_id for msg in msgs if not msg.has_error()
         }
+        log(
+            DEBUG,
+            "[Stage 2] Received masked vectors from %s clients.",
+            len(state.active_node_ids),
+        )
 
         # Clear cache
         del state.forward_ciphertexts, state.forward_srcs, state.nid_to_fitins

@@ -485,9 +529,15 @@ class SecAggPlusWorkflow:
         masked_vector = parameters_mod(masked_vector, state.mod_range)
         state.aggregate_ndarrays = masked_vector
 
+        # Backward compatibility with Strategy
+        for msg in msgs:
+            fitres = compat.recordset_to_fitres(msg.content, True)
+            proxy = state.nid_to_proxies[msg.metadata.src_node_id]
+            state.legacy_results.append((proxy, fitres))
+
         return self._check_threshold(state)
 
-    def unmask_stage(  # pylint: disable=R0912, R0914
+    def unmask_stage(  # pylint: disable=R0912, R0914, R0915
         self, driver: Driver, context: LegacyContext, state: WorkflowState
     ) -> bool:
         """Execute the 'unmask' stage."""

@@ -516,12 +566,22 @@ class SecAggPlusWorkflow:
                 ttl="",
             )
 
+        log(
+            DEBUG,
+            "[Stage 3] Requesting key shares from %s clients to remove masks.",
+            len(state.active_node_ids),
+        )
         msgs = driver.send_and_receive(
             [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
         )
         state.active_node_ids = {
             msg.metadata.src_node_id for msg in msgs if not msg.has_error()
         }
+        log(
+            DEBUG,
+            "[Stage 3] Received key shares from %s clients.",
+            len(state.active_node_ids),
+        )
 
         # Build collected shares dict
         collected_shares_dict: Dict[int, List[bytes]] = {}

@@ -534,7 +594,7 @@ class SecAggPlusWorkflow:
         for owner_nid, share in zip(nids, shares):
             collected_shares_dict[owner_nid].append(share)
 
-        # Remove
+        # Remove masks for every active client after collect_masked_vectors stage
         masked_vector = state.aggregate_ndarrays
         del state.aggregate_ndarrays
         for nid, share_list in collected_shares_dict.items():

@@ -584,18 +644,30 @@ class SecAggPlusWorkflow:
         for vec in aggregated_vector:
             vec += offset
             vec *= inv_dq_total_ratio
-
-
-
-
-
-
-
-
+
+        # Backward compatibility with Strategy
+        results = state.legacy_results
+        parameters = ndarrays_to_parameters(aggregated_vector)
+        for _, fitres in results:
+            fitres.parameters = parameters
+
+        # No exception/failure handling currently
+        log(
+            INFO,
+            "aggregate_fit: received %s results and %s failures",
+            len(results),
             0,
-            driver.grpc_driver,  # type: ignore
-            False,
-            driver.run_id,  # type: ignore
         )
-        context.strategy.aggregate_fit(current_round,
+        aggregated_result = context.strategy.aggregate_fit(current_round, results, [])
+        parameters_aggregated, metrics_aggregated = aggregated_result
+
+        # Update the parameters and write history
+        if parameters_aggregated:
+            paramsrecord = compat.parameters_to_parametersrecord(
+                parameters_aggregated, True
+            )
+            context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
+            context.history.add_metrics_distributed_fit(
+                server_round=current_round, metrics=metrics_aggregated
+            )
         return True

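The `offset`/`inv_dq_total_ratio` arithmetic above undoes the client-side quantization described by `clipping_range` and `quantization_range` in the SecAgg(+) docstrings. A hedged, self-contained sketch of why the sum of N quantized vectors needs a per-client offset plus a rescale (plain `np.round` is used here for illustration; the actual flwr quantizer may round differently):

    import numpy as np

    C = 8.0     # clipping_range: floats are clipped to [-C, C]
    Q = 2**22   # quantization_range: integers land in [0, Q)

    def quantize(x: np.ndarray) -> np.ndarray:
        # Map [-C, C] floats to integers in [0, Q)
        return np.round((np.clip(x, -C, C) + C) * (Q / (2 * C))).astype(np.int64)

    xs = [np.array([0.5, -1.25]), np.array([2.0, 0.75])]  # two clients
    total = sum(quantize(x) for x in xs)  # what the server holds after unmasking

    # Each addend carried a +C shift, so rescale and subtract C per client
    n = len(xs)
    dequantized_sum = total * (2 * C / Q) - C * n
    assert np.allclose(dequantized_sum, sum(xs), atol=1e-4)
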