flwr-nightly 1.8.0.dev20240310__py3-none-any.whl → 1.8.0.dev20240312__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/new/new.py +6 -3
- flwr/cli/utils.py +14 -1
- flwr/client/app.py +25 -2
- flwr/client/mod/__init__.py +2 -1
- flwr/client/mod/secure_aggregation/__init__.py +2 -0
- flwr/client/mod/secure_aggregation/secagg_mod.py +30 -0
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +42 -51
- flwr/common/logger.py +6 -8
- flwr/common/pyproject.py +41 -0
- flwr/common/secure_aggregation/secaggplus_constants.py +2 -2
- flwr/server/superlink/state/in_memory_state.py +34 -32
- flwr/server/workflow/__init__.py +2 -1
- flwr/server/workflow/default_workflows.py +39 -40
- flwr/server/workflow/secure_aggregation/__init__.py +2 -0
- flwr/server/workflow/secure_aggregation/secagg_workflow.py +112 -0
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +98 -26
- {flwr_nightly-1.8.0.dev20240310.dist-info → flwr_nightly-1.8.0.dev20240312.dist-info}/METADATA +1 -1
- {flwr_nightly-1.8.0.dev20240310.dist-info → flwr_nightly-1.8.0.dev20240312.dist-info}/RECORD +21 -18
- {flwr_nightly-1.8.0.dev20240310.dist-info → flwr_nightly-1.8.0.dev20240312.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240310.dist-info → flwr_nightly-1.8.0.dev20240312.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.8.0.dev20240310.dist-info → flwr_nightly-1.8.0.dev20240312.dist-info}/entry_points.txt +0 -0
flwr/cli/new/new.py
CHANGED
@@ -22,7 +22,7 @@ from typing import Dict, Optional
|
|
22
22
|
import typer
|
23
23
|
from typing_extensions import Annotated
|
24
24
|
|
25
|
-
from ..utils import prompt_options
|
25
|
+
from ..utils import prompt_options, prompt_text
|
26
26
|
|
27
27
|
|
28
28
|
class MlFramework(str, Enum):
|
@@ -72,9 +72,9 @@ def render_and_create(file_path: str, template: str, context: Dict[str, str]) ->
|
|
72
72
|
|
73
73
|
def new(
|
74
74
|
project_name: Annotated[
|
75
|
-
str,
|
75
|
+
Optional[str],
|
76
76
|
typer.Argument(metavar="project_name", help="The name of the project"),
|
77
|
-
],
|
77
|
+
] = None,
|
78
78
|
framework: Annotated[
|
79
79
|
Optional[MlFramework],
|
80
80
|
typer.Option(case_sensitive=False, help="The ML framework to use"),
|
@@ -83,6 +83,9 @@ def new(
|
|
83
83
|
"""Create new Flower project."""
|
84
84
|
print(f"Creating Flower project {project_name}...")
|
85
85
|
|
86
|
+
if project_name is None:
|
87
|
+
project_name = prompt_text("Please provide project name")
|
88
|
+
|
86
89
|
if framework is not None:
|
87
90
|
framework_str = str(framework.value)
|
88
91
|
else:
|
flwr/cli/utils.py
CHANGED
@@ -14,11 +14,24 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
"""Flower command line interface utils."""
|
16
16
|
|
17
|
-
from typing import List
|
17
|
+
from typing import List, cast
|
18
18
|
|
19
19
|
import typer
|
20
20
|
|
21
21
|
|
22
|
+
def prompt_text(text: str) -> str:
|
23
|
+
"""Ask user to enter text input."""
|
24
|
+
while True:
|
25
|
+
result = typer.prompt(
|
26
|
+
typer.style(f"\n💬 {text}", fg=typer.colors.MAGENTA, bold=True)
|
27
|
+
)
|
28
|
+
if len(result) > 0:
|
29
|
+
break
|
30
|
+
print(typer.style("❌ Invalid entry", fg=typer.colors.RED, bold=True))
|
31
|
+
|
32
|
+
return cast(str, result)
|
33
|
+
|
34
|
+
|
22
35
|
def prompt_options(text: str, options: List[str]) -> str:
|
23
36
|
"""Ask user to select one of the given options and return the selected item."""
|
24
37
|
# Turn options into a list with index as in " [ 0] quickstart-pytorch"
|
flwr/client/app.py
CHANGED
@@ -456,7 +456,19 @@ def _start_client_internal(
|
|
456
456
|
time.sleep(3) # Wait for 3s before asking again
|
457
457
|
continue
|
458
458
|
|
459
|
-
log(INFO, "
|
459
|
+
log(INFO, "")
|
460
|
+
log(
|
461
|
+
INFO,
|
462
|
+
"[RUN %s, ROUND %s]",
|
463
|
+
message.metadata.run_id,
|
464
|
+
message.metadata.group_id,
|
465
|
+
)
|
466
|
+
log(
|
467
|
+
INFO,
|
468
|
+
"Received: %s message %s",
|
469
|
+
message.metadata.message_type,
|
470
|
+
message.metadata.message_id,
|
471
|
+
)
|
460
472
|
|
461
473
|
# Handle control message
|
462
474
|
out_message, sleep_duration = handle_control_message(message)
|
@@ -484,7 +496,18 @@ def _start_client_internal(
|
|
484
496
|
|
485
497
|
# Send
|
486
498
|
send(out_message)
|
487
|
-
log(
|
499
|
+
log(
|
500
|
+
INFO,
|
501
|
+
"[RUN %s, ROUND %s]",
|
502
|
+
out_message.metadata.run_id,
|
503
|
+
out_message.metadata.group_id,
|
504
|
+
)
|
505
|
+
log(
|
506
|
+
INFO,
|
507
|
+
"Sent: %s reply to message %s",
|
508
|
+
out_message.metadata.message_type,
|
509
|
+
message.metadata.message_id,
|
510
|
+
)
|
488
511
|
|
489
512
|
# Unregister node
|
490
513
|
if delete_node is not None:
|
flwr/client/mod/__init__.py
CHANGED
@@ -17,7 +17,7 @@
|
|
17
17
|
|
18
18
|
from .centraldp_mods import adaptiveclipping_mod, fixedclipping_mod
|
19
19
|
from .localdp_mod import LocalDpMod
|
20
|
-
from .secure_aggregation
|
20
|
+
from .secure_aggregation import secagg_mod, secaggplus_mod
|
21
21
|
from .utils import make_ffn
|
22
22
|
|
23
23
|
__all__ = [
|
@@ -25,5 +25,6 @@ __all__ = [
|
|
25
25
|
"fixedclipping_mod",
|
26
26
|
"LocalDpMod",
|
27
27
|
"make_ffn",
|
28
|
+
"secagg_mod",
|
28
29
|
"secaggplus_mod",
|
29
30
|
]
|
@@ -0,0 +1,30 @@
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Modifier for the SecAgg protocol."""
|
16
|
+
|
17
|
+
|
18
|
+
from flwr.client.typing import ClientAppCallable
|
19
|
+
from flwr.common import Context, Message
|
20
|
+
|
21
|
+
from .secaggplus_mod import secaggplus_mod
|
22
|
+
|
23
|
+
|
24
|
+
def secagg_mod(
|
25
|
+
msg: Message,
|
26
|
+
ctxt: Context,
|
27
|
+
call_next: ClientAppCallable,
|
28
|
+
) -> Message:
|
29
|
+
"""Handle incoming message and return results, following the SecAgg protocol."""
|
30
|
+
return secaggplus_mod(msg, ctxt, call_next)
|
@@ -12,19 +12,20 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
"""
|
15
|
+
"""Modifier for the SecAgg+ protocol."""
|
16
16
|
|
17
17
|
|
18
18
|
import os
|
19
19
|
from dataclasses import dataclass, field
|
20
|
-
from logging import
|
21
|
-
from typing import Any,
|
20
|
+
from logging import DEBUG, WARNING
|
21
|
+
from typing import Any, Dict, List, Tuple, cast
|
22
22
|
|
23
23
|
from flwr.client.typing import ClientAppCallable
|
24
24
|
from flwr.common import (
|
25
25
|
ConfigsRecord,
|
26
26
|
Context,
|
27
27
|
Message,
|
28
|
+
Parameters,
|
28
29
|
RecordSet,
|
29
30
|
ndarray_to_bytes,
|
30
31
|
parameters_to_ndarrays,
|
@@ -62,7 +63,7 @@ from flwr.common.secure_aggregation.secaggplus_utils import (
|
|
62
63
|
share_keys_plaintext_concat,
|
63
64
|
share_keys_plaintext_separate,
|
64
65
|
)
|
65
|
-
from flwr.common.typing import ConfigsRecordValues
|
66
|
+
from flwr.common.typing import ConfigsRecordValues
|
66
67
|
|
67
68
|
|
68
69
|
@dataclass
|
@@ -132,18 +133,6 @@ class SecAggPlusState:
|
|
132
133
|
return ret
|
133
134
|
|
134
135
|
|
135
|
-
def _get_fit_fn(
|
136
|
-
msg: Message, ctxt: Context, call_next: ClientAppCallable
|
137
|
-
) -> Callable[[], FitRes]:
|
138
|
-
"""Get the fit function."""
|
139
|
-
|
140
|
-
def fit() -> FitRes:
|
141
|
-
out_msg = call_next(msg, ctxt)
|
142
|
-
return compat.recordset_to_fitres(out_msg.content, keep_input=False)
|
143
|
-
|
144
|
-
return fit
|
145
|
-
|
146
|
-
|
147
136
|
def secaggplus_mod(
|
148
137
|
msg: Message,
|
149
138
|
ctxt: Context,
|
@@ -173,25 +162,32 @@ def secaggplus_mod(
|
|
173
162
|
check_configs(state.current_stage, configs)
|
174
163
|
|
175
164
|
# Execute
|
165
|
+
out_content = RecordSet()
|
176
166
|
if state.current_stage == Stage.SETUP:
|
177
167
|
state.nid = msg.metadata.dst_node_id
|
178
168
|
res = _setup(state, configs)
|
179
169
|
elif state.current_stage == Stage.SHARE_KEYS:
|
180
170
|
res = _share_keys(state, configs)
|
181
|
-
elif state.current_stage == Stage.
|
182
|
-
|
183
|
-
|
171
|
+
elif state.current_stage == Stage.COLLECT_MASKED_VECTORS:
|
172
|
+
out_msg = call_next(msg, ctxt)
|
173
|
+
out_content = out_msg.content
|
174
|
+
fitres = compat.recordset_to_fitres(out_content, keep_input=True)
|
175
|
+
res = _collect_masked_vectors(
|
176
|
+
state, configs, fitres.num_examples, fitres.parameters
|
177
|
+
)
|
178
|
+
for p_record in out_content.parameters_records.values():
|
179
|
+
p_record.clear()
|
184
180
|
elif state.current_stage == Stage.UNMASK:
|
185
181
|
res = _unmask(state, configs)
|
186
182
|
else:
|
187
|
-
raise ValueError(f"Unknown
|
183
|
+
raise ValueError(f"Unknown SecAgg/SecAgg+ stage: {state.current_stage}")
|
188
184
|
|
189
185
|
# Save state
|
190
186
|
ctxt.state.configs_records[RECORD_KEY_STATE] = ConfigsRecord(state.to_dict())
|
191
187
|
|
192
188
|
# Return message
|
193
|
-
|
194
|
-
return msg.create_reply(
|
189
|
+
out_content.configs_records[RECORD_KEY_CONFIGS] = ConfigsRecord(res, False)
|
190
|
+
return msg.create_reply(out_content, ttl="")
|
195
191
|
|
196
192
|
|
197
193
|
def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
|
@@ -199,7 +195,7 @@ def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
|
|
199
195
|
# Check the existence of Config.STAGE
|
200
196
|
if Key.STAGE not in configs:
|
201
197
|
raise KeyError(
|
202
|
-
f"The required key '{Key.STAGE}' is missing from the
|
198
|
+
f"The required key '{Key.STAGE}' is missing from the ConfigsRecord."
|
203
199
|
)
|
204
200
|
|
205
201
|
# Check the value type of the Config.STAGE
|
@@ -215,7 +211,7 @@ def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
|
|
215
211
|
if current_stage != Stage.UNMASK:
|
216
212
|
log(WARNING, "Restart from the setup stage")
|
217
213
|
# If stage is not "setup",
|
218
|
-
# the stage from
|
214
|
+
# the stage from configs should be the expected next stage
|
219
215
|
else:
|
220
216
|
stages = Stage.all()
|
221
217
|
expected_next_stage = stages[(stages.index(current_stage) + 1) % len(stages)]
|
@@ -229,7 +225,7 @@ def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
|
|
229
225
|
# pylint: disable-next=too-many-branches
|
230
226
|
def check_configs(stage: str, configs: ConfigsRecord) -> None:
|
231
227
|
"""Check the validity of the configs."""
|
232
|
-
# Check
|
228
|
+
# Check configs for the setup stage
|
233
229
|
if stage == Stage.SETUP:
|
234
230
|
key_type_pairs = [
|
235
231
|
(Key.SAMPLE_NUMBER, int),
|
@@ -243,7 +239,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
|
|
243
239
|
if key not in configs:
|
244
240
|
raise KeyError(
|
245
241
|
f"Stage {Stage.SETUP}: the required key '{key}' is "
|
246
|
-
"missing from the
|
242
|
+
"missing from the ConfigsRecord."
|
247
243
|
)
|
248
244
|
# Bool is a subclass of int in Python,
|
249
245
|
# so `isinstance(v, int)` will return True even if v is a boolean.
|
@@ -266,7 +262,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
|
|
266
262
|
f"Stage {Stage.SHARE_KEYS}: "
|
267
263
|
f"the value for the key '{key}' must be a list of two bytes."
|
268
264
|
)
|
269
|
-
elif stage == Stage.
|
265
|
+
elif stage == Stage.COLLECT_MASKED_VECTORS:
|
270
266
|
key_type_pairs = [
|
271
267
|
(Key.CIPHERTEXT_LIST, bytes),
|
272
268
|
(Key.SOURCE_LIST, int),
|
@@ -274,9 +270,9 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
|
|
274
270
|
for key, expected_type in key_type_pairs:
|
275
271
|
if key not in configs:
|
276
272
|
raise KeyError(
|
277
|
-
f"Stage {Stage.
|
273
|
+
f"Stage {Stage.COLLECT_MASKED_VECTORS}: "
|
278
274
|
f"the required key '{key}' is "
|
279
|
-
"missing from the
|
275
|
+
"missing from the ConfigsRecord."
|
280
276
|
)
|
281
277
|
if not isinstance(configs[key], list) or any(
|
282
278
|
elm
|
@@ -285,7 +281,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
|
|
285
281
|
if type(elm) is not expected_type
|
286
282
|
):
|
287
283
|
raise TypeError(
|
288
|
-
f"Stage {Stage.
|
284
|
+
f"Stage {Stage.COLLECT_MASKED_VECTORS}: "
|
289
285
|
f"the value for the key '{key}' "
|
290
286
|
f"must be of type List[{expected_type.__name__}]"
|
291
287
|
)
|
@@ -299,7 +295,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
|
|
299
295
|
raise KeyError(
|
300
296
|
f"Stage {Stage.UNMASK}: "
|
301
297
|
f"the required key '{key}' is "
|
302
|
-
"missing from the
|
298
|
+
"missing from the ConfigsRecord."
|
303
299
|
)
|
304
300
|
if not isinstance(configs[key], list) or any(
|
305
301
|
elm
|
@@ -322,7 +318,7 @@ def _setup(
|
|
322
318
|
# Assigning parameter values to object fields
|
323
319
|
sec_agg_param_dict = configs
|
324
320
|
state.sample_num = cast(int, sec_agg_param_dict[Key.SAMPLE_NUMBER])
|
325
|
-
log(
|
321
|
+
log(DEBUG, "Node %d: starting stage 0...", state.nid)
|
326
322
|
|
327
323
|
state.share_num = cast(int, sec_agg_param_dict[Key.SHARE_NUMBER])
|
328
324
|
state.threshold = cast(int, sec_agg_param_dict[Key.THRESHOLD])
|
@@ -347,7 +343,7 @@ def _setup(
|
|
347
343
|
|
348
344
|
state.sk1, state.pk1 = private_key_to_bytes(sk1), public_key_to_bytes(pk1)
|
349
345
|
state.sk2, state.pk2 = private_key_to_bytes(sk2), public_key_to_bytes(pk2)
|
350
|
-
log(
|
346
|
+
log(DEBUG, "Node %d: stage 0 completes. uploading public keys...", state.nid)
|
351
347
|
return {Key.PUBLIC_KEY_1: state.pk1, Key.PUBLIC_KEY_2: state.pk2}
|
352
348
|
|
353
349
|
|
@@ -357,7 +353,7 @@ def _share_keys(
|
|
357
353
|
) -> Dict[str, ConfigsRecordValues]:
|
358
354
|
named_bytes_tuples = cast(Dict[str, Tuple[bytes, bytes]], configs)
|
359
355
|
key_dict = {int(sid): (pk1, pk2) for sid, (pk1, pk2) in named_bytes_tuples.items()}
|
360
|
-
log(
|
356
|
+
log(DEBUG, "Node %d: starting stage 1...", state.nid)
|
361
357
|
state.public_keys_dict = key_dict
|
362
358
|
|
363
359
|
# Check if the size is larger than threshold
|
@@ -409,17 +405,18 @@ def _share_keys(
|
|
409
405
|
dsts.append(nid)
|
410
406
|
ciphertexts.append(ciphertext)
|
411
407
|
|
412
|
-
log(
|
408
|
+
log(DEBUG, "Node %d: stage 1 completes. uploading key shares...", state.nid)
|
413
409
|
return {Key.DESTINATION_LIST: dsts, Key.CIPHERTEXT_LIST: ciphertexts}
|
414
410
|
|
415
411
|
|
416
412
|
# pylint: disable-next=too-many-locals
|
417
|
-
def
|
413
|
+
def _collect_masked_vectors(
|
418
414
|
state: SecAggPlusState,
|
419
415
|
configs: ConfigsRecord,
|
420
|
-
|
416
|
+
num_examples: int,
|
417
|
+
updated_parameters: Parameters,
|
421
418
|
) -> Dict[str, ConfigsRecordValues]:
|
422
|
-
log(
|
419
|
+
log(DEBUG, "Node %d: starting stage 2...", state.nid)
|
423
420
|
available_clients: List[int] = []
|
424
421
|
ciphertexts = cast(List[bytes], configs[Key.CIPHERTEXT_LIST])
|
425
422
|
srcs = cast(List[int], configs[Key.SOURCE_LIST])
|
@@ -447,26 +444,20 @@ def _collect_masked_input(
|
|
447
444
|
state.rd_seed_share_dict[src] = rd_seed_share
|
448
445
|
state.sk1_share_dict[src] = sk1_share
|
449
446
|
|
450
|
-
# Fit
|
451
|
-
|
452
|
-
if len(fit_res.metrics) > 0:
|
453
|
-
log(
|
454
|
-
WARNING,
|
455
|
-
"The metrics in FitRes will not be preserved or sent to the server.",
|
456
|
-
)
|
457
|
-
ratio = fit_res.num_examples / state.max_weight
|
447
|
+
# Fit
|
448
|
+
ratio = num_examples / state.max_weight
|
458
449
|
if ratio > 1:
|
459
450
|
log(
|
460
451
|
WARNING,
|
461
452
|
"Potential overflow warning: the provided weight (%s) exceeds the specified"
|
462
453
|
" max_weight (%s). This may lead to overflow issues.",
|
463
|
-
|
454
|
+
num_examples,
|
464
455
|
state.max_weight,
|
465
456
|
)
|
466
457
|
q_ratio = round(ratio * state.target_range)
|
467
458
|
dq_ratio = q_ratio / state.target_range
|
468
459
|
|
469
|
-
parameters = parameters_to_ndarrays(
|
460
|
+
parameters = parameters_to_ndarrays(updated_parameters)
|
470
461
|
parameters = parameters_multiply(parameters, dq_ratio)
|
471
462
|
|
472
463
|
# Quantize parameter update (vector)
|
@@ -500,7 +491,7 @@ def _collect_masked_input(
|
|
500
491
|
|
501
492
|
# Take mod of final weight update vector and return to server
|
502
493
|
quantized_parameters = parameters_mod(quantized_parameters, state.mod_range)
|
503
|
-
log(
|
494
|
+
log(DEBUG, "Node %d: stage 2 completed, uploading masked parameters...", state.nid)
|
504
495
|
return {
|
505
496
|
Key.MASKED_PARAMETERS: [ndarray_to_bytes(arr) for arr in quantized_parameters]
|
506
497
|
}
|
@@ -509,7 +500,7 @@ def _collect_masked_input(
|
|
509
500
|
def _unmask(
|
510
501
|
state: SecAggPlusState, configs: ConfigsRecord
|
511
502
|
) -> Dict[str, ConfigsRecordValues]:
|
512
|
-
log(
|
503
|
+
log(DEBUG, "Node %d: starting stage 3...", state.nid)
|
513
504
|
|
514
505
|
active_nids = cast(List[int], configs[Key.ACTIVE_NODE_ID_LIST])
|
515
506
|
dead_nids = cast(List[int], configs[Key.DEAD_NODE_ID_LIST])
|
@@ -523,5 +514,5 @@ def _unmask(
|
|
523
514
|
shares += [state.rd_seed_share_dict[nid] for nid in active_nids]
|
524
515
|
shares += [state.sk1_share_dict[nid] for nid in dead_nids]
|
525
516
|
|
526
|
-
log(
|
517
|
+
log(DEBUG, "Node %d: stage 3 completes. uploading key shares...", state.nid)
|
527
518
|
return {Key.NODE_ID_LIST: all_nids, Key.SHARE_LIST: shares}
|
flwr/common/logger.py
CHANGED
@@ -168,11 +168,10 @@ def warn_experimental_feature(name: str) -> None:
|
|
168
168
|
"""Warn the user when they use an experimental feature."""
|
169
169
|
log(
|
170
170
|
WARN,
|
171
|
-
"""
|
172
|
-
EXPERIMENTAL FEATURE: %s
|
171
|
+
"""EXPERIMENTAL FEATURE: %s
|
173
172
|
|
174
|
-
|
175
|
-
|
173
|
+
This is an experimental feature. It could change significantly or be removed
|
174
|
+
entirely in future versions of Flower.
|
176
175
|
""",
|
177
176
|
name,
|
178
177
|
)
|
@@ -182,11 +181,10 @@ def warn_deprecated_feature(name: str) -> None:
|
|
182
181
|
"""Warn the user when they use a deprecated feature."""
|
183
182
|
log(
|
184
183
|
WARN,
|
185
|
-
"""
|
186
|
-
DEPRECATED FEATURE: %s
|
184
|
+
"""DEPRECATED FEATURE: %s
|
187
185
|
|
188
|
-
|
189
|
-
|
186
|
+
This is a deprecated feature. It will be removed
|
187
|
+
entirely in future versions of Flower.
|
190
188
|
""",
|
191
189
|
name,
|
192
190
|
)
|
flwr/common/pyproject.py
ADDED
@@ -0,0 +1,41 @@
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Validates the project's name property."""
|
16
|
+
|
17
|
+
import re
|
18
|
+
|
19
|
+
|
20
|
+
def validate_project_name(name: str) -> bool:
|
21
|
+
"""Validate the project name against PEP 621 and PEP 503 specifications.
|
22
|
+
|
23
|
+
Conventions at a glance:
|
24
|
+
- Must be lowercase
|
25
|
+
- Must not contain special characters
|
26
|
+
- Must use hyphens(recommended) or underscores. No spaces.
|
27
|
+
- Recommended to be no more than 40 characters long (But it can be)
|
28
|
+
|
29
|
+
Parameters
|
30
|
+
----------
|
31
|
+
name : str
|
32
|
+
The project name to validate.
|
33
|
+
|
34
|
+
Returns
|
35
|
+
-------
|
36
|
+
bool
|
37
|
+
True if the name is valid, False otherwise.
|
38
|
+
"""
|
39
|
+
if not name or len(name) > 40 or not re.match(r"^[a-z0-9-_]+$", name):
|
40
|
+
return False
|
41
|
+
return True
|
@@ -27,9 +27,9 @@ class Stage:
|
|
27
27
|
|
28
28
|
SETUP = "setup"
|
29
29
|
SHARE_KEYS = "share_keys"
|
30
|
-
|
30
|
+
COLLECT_MASKED_VECTORS = "collect_masked_vectors"
|
31
31
|
UNMASK = "unmask"
|
32
|
-
_stages = (SETUP, SHARE_KEYS,
|
32
|
+
_stages = (SETUP, SHARE_KEYS, COLLECT_MASKED_VECTORS, UNMASK)
|
33
33
|
|
34
34
|
@classmethod
|
35
35
|
def all(cls) -> tuple[str, str, str, str]:
|
@@ -122,7 +122,8 @@ class InMemoryState(State):
|
|
122
122
|
task_res.task_id = str(task_id)
|
123
123
|
task_res.task.created_at = created_at.isoformat()
|
124
124
|
task_res.task.ttl = ttl.isoformat()
|
125
|
-
self.
|
125
|
+
with self.lock:
|
126
|
+
self.task_res_store[task_id] = task_res
|
126
127
|
|
127
128
|
# Return the new task_id
|
128
129
|
return task_id
|
@@ -132,46 +133,47 @@ class InMemoryState(State):
|
|
132
133
|
if limit is not None and limit < 1:
|
133
134
|
raise AssertionError("`limit` must be >= 1")
|
134
135
|
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
136
|
+
with self.lock:
|
137
|
+
# Find TaskRes that were not delivered yet
|
138
|
+
task_res_list: List[TaskRes] = []
|
139
|
+
for _, task_res in self.task_res_store.items():
|
140
|
+
if (
|
141
|
+
UUID(task_res.task.ancestry[0]) in task_ids
|
142
|
+
and task_res.task.delivered_at == ""
|
143
|
+
):
|
144
|
+
task_res_list.append(task_res)
|
145
|
+
if limit and len(task_res_list) == limit:
|
146
|
+
break
|
145
147
|
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
148
|
+
# Mark all of them as delivered
|
149
|
+
delivered_at = now().isoformat()
|
150
|
+
for task_res in task_res_list:
|
151
|
+
task_res.task.delivered_at = delivered_at
|
150
152
|
|
151
|
-
|
152
|
-
|
153
|
+
# Return TaskRes
|
154
|
+
return task_res_list
|
153
155
|
|
154
156
|
def delete_tasks(self, task_ids: Set[UUID]) -> None:
|
155
157
|
"""Delete all delivered TaskIns/TaskRes pairs."""
|
156
158
|
task_ins_to_be_deleted: Set[UUID] = set()
|
157
159
|
task_res_to_be_deleted: Set[UUID] = set()
|
158
160
|
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
161
|
+
with self.lock:
|
162
|
+
for task_ins_id in task_ids:
|
163
|
+
# Find the task_id of the matching task_res
|
164
|
+
for task_res_id, task_res in self.task_res_store.items():
|
165
|
+
if UUID(task_res.task.ancestry[0]) != task_ins_id:
|
166
|
+
continue
|
167
|
+
if task_res.task.delivered_at == "":
|
168
|
+
continue
|
169
|
+
|
170
|
+
task_ins_to_be_deleted.add(task_ins_id)
|
171
|
+
task_res_to_be_deleted.add(task_res_id)
|
172
|
+
|
173
|
+
for task_id in task_ins_to_be_deleted:
|
172
174
|
del self.task_ins_store[task_id]
|
173
|
-
|
174
|
-
|
175
|
+
for task_id in task_res_to_be_deleted:
|
176
|
+
del self.task_res_store[task_id]
|
175
177
|
|
176
178
|
def num_task_ins(self) -> int:
|
177
179
|
"""Calculate the number of task_ins in store.
|
flwr/server/workflow/__init__.py
CHANGED
@@ -16,9 +16,10 @@
|
|
16
16
|
|
17
17
|
|
18
18
|
from .default_workflows import DefaultWorkflow
|
19
|
-
from .secure_aggregation import SecAggPlusWorkflow
|
19
|
+
from .secure_aggregation import SecAggPlusWorkflow, SecAggWorkflow
|
20
20
|
|
21
21
|
__all__ = [
|
22
22
|
"DefaultWorkflow",
|
23
23
|
"SecAggPlusWorkflow",
|
24
|
+
"SecAggWorkflow",
|
24
25
|
]
|