flwr-nightly 1.8.0.dev20240228__py3-none-any.whl → 1.8.0.dev20240229__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +55 -75
- flwr/common/secure_aggregation/secaggplus_constants.py +49 -27
- flwr/proto/error_pb2.py +26 -0
- flwr/proto/error_pb2.pyi +25 -0
- flwr/proto/error_pb2_grpc.py +4 -0
- flwr/proto/error_pb2_grpc.pyi +4 -0
- flwr/proto/task_pb2.py +8 -7
- flwr/proto/task_pb2.pyi +7 -2
- flwr/server/strategy/__init__.py +11 -3
- flwr/server/strategy/dp_adaptive_clipping.py +205 -1
- flwr/server/superlink/driver/driver_servicer.py +4 -4
- flwr/simulation/__init__.py +3 -0
- flwr/simulation/run_simulation.py +177 -0
- {flwr_nightly-1.8.0.dev20240228.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/METADATA +1 -1
- {flwr_nightly-1.8.0.dev20240228.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/RECORD +18 -13
- {flwr_nightly-1.8.0.dev20240228.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/entry_points.txt +1 -0
- {flwr_nightly-1.8.0.dev20240228.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240228.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/WHEEL +0 -0
@@ -52,31 +52,10 @@ from flwr.common.secure_aggregation.ndarrays_arithmetic import (
|
|
52
52
|
)
|
53
53
|
from flwr.common.secure_aggregation.quantization import quantize
|
54
54
|
from flwr.common.secure_aggregation.secaggplus_constants import (
|
55
|
-
KEY_ACTIVE_SECURE_ID_LIST,
|
56
|
-
KEY_CIPHERTEXT_LIST,
|
57
|
-
KEY_CLIPPING_RANGE,
|
58
|
-
KEY_DEAD_SECURE_ID_LIST,
|
59
|
-
KEY_DESTINATION_LIST,
|
60
|
-
KEY_MASKED_PARAMETERS,
|
61
|
-
KEY_MOD_RANGE,
|
62
|
-
KEY_PUBLIC_KEY_1,
|
63
|
-
KEY_PUBLIC_KEY_2,
|
64
|
-
KEY_SAMPLE_NUMBER,
|
65
|
-
KEY_SECURE_ID,
|
66
|
-
KEY_SECURE_ID_LIST,
|
67
|
-
KEY_SHARE_LIST,
|
68
|
-
KEY_SHARE_NUMBER,
|
69
|
-
KEY_SOURCE_LIST,
|
70
|
-
KEY_STAGE,
|
71
|
-
KEY_TARGET_RANGE,
|
72
|
-
KEY_THRESHOLD,
|
73
55
|
RECORD_KEY_CONFIGS,
|
74
56
|
RECORD_KEY_STATE,
|
75
|
-
|
76
|
-
|
77
|
-
STAGE_SHARE_KEYS,
|
78
|
-
STAGE_UNMASK,
|
79
|
-
STAGES,
|
57
|
+
Key,
|
58
|
+
Stage,
|
80
59
|
)
|
81
60
|
from flwr.common.secure_aggregation.secaggplus_utils import (
|
82
61
|
pseudo_rand_gen,
|
@@ -91,7 +70,7 @@ from flwr.common.typing import ConfigsRecordValues, FitRes
|
|
91
70
|
class SecAggPlusState:
|
92
71
|
"""State of the SecAgg+ protocol."""
|
93
72
|
|
94
|
-
current_stage: str =
|
73
|
+
current_stage: str = Stage.UNMASK
|
95
74
|
|
96
75
|
sid: int = 0
|
97
76
|
sample_num: int = 0
|
@@ -187,20 +166,20 @@ def secaggplus_mod(
|
|
187
166
|
check_stage(state.current_stage, configs)
|
188
167
|
|
189
168
|
# Update the current stage
|
190
|
-
state.current_stage = cast(str, configs.pop(
|
169
|
+
state.current_stage = cast(str, configs.pop(Key.STAGE))
|
191
170
|
|
192
171
|
# Check the validity of the configs based on the current stage
|
193
172
|
check_configs(state.current_stage, configs)
|
194
173
|
|
195
174
|
# Execute
|
196
|
-
if state.current_stage ==
|
175
|
+
if state.current_stage == Stage.SETUP:
|
197
176
|
res = _setup(state, configs)
|
198
|
-
elif state.current_stage ==
|
177
|
+
elif state.current_stage == Stage.SHARE_KEYS:
|
199
178
|
res = _share_keys(state, configs)
|
200
|
-
elif state.current_stage ==
|
179
|
+
elif state.current_stage == Stage.COLLECT_MASKED_INPUT:
|
201
180
|
fit = _get_fit_fn(msg, ctxt, call_next)
|
202
181
|
res = _collect_masked_input(state, configs, fit)
|
203
|
-
elif state.current_stage ==
|
182
|
+
elif state.current_stage == Stage.UNMASK:
|
204
183
|
res = _unmask(state, configs)
|
205
184
|
else:
|
206
185
|
raise ValueError(f"Unknown secagg stage: {state.current_stage}")
|
@@ -215,28 +194,29 @@ def secaggplus_mod(
|
|
215
194
|
|
216
195
|
def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
|
217
196
|
"""Check the validity of the next stage."""
|
218
|
-
# Check the existence of
|
219
|
-
if
|
197
|
+
# Check the existence of Config.STAGE
|
198
|
+
if Key.STAGE not in configs:
|
220
199
|
raise KeyError(
|
221
|
-
f"The required key '{
|
200
|
+
f"The required key '{Key.STAGE}' is missing from the input `named_values`."
|
222
201
|
)
|
223
202
|
|
224
|
-
# Check the value type of the
|
225
|
-
next_stage = configs[
|
203
|
+
# Check the value type of the Config.STAGE
|
204
|
+
next_stage = configs[Key.STAGE]
|
226
205
|
if not isinstance(next_stage, str):
|
227
206
|
raise TypeError(
|
228
|
-
f"The value for the key '{
|
207
|
+
f"The value for the key '{Key.STAGE}' must be of type {str}, "
|
229
208
|
f"but got {type(next_stage)} instead."
|
230
209
|
)
|
231
210
|
|
232
211
|
# Check the validity of the next stage
|
233
|
-
if next_stage ==
|
234
|
-
if current_stage !=
|
212
|
+
if next_stage == Stage.SETUP:
|
213
|
+
if current_stage != Stage.UNMASK:
|
235
214
|
log(WARNING, "Restart from the setup stage")
|
236
215
|
# If stage is not "setup",
|
237
216
|
# the stage from `named_values` should be the expected next stage
|
238
217
|
else:
|
239
|
-
|
218
|
+
stages = Stage.all()
|
219
|
+
expected_next_stage = stages[(stages.index(current_stage) + 1) % len(stages)]
|
240
220
|
if next_stage != expected_next_stage:
|
241
221
|
raise ValueError(
|
242
222
|
"Abort secure aggregation: "
|
@@ -248,20 +228,20 @@ def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
|
|
248
228
|
def check_configs(stage: str, configs: ConfigsRecord) -> None:
|
249
229
|
"""Check the validity of the configs."""
|
250
230
|
# Check `named_values` for the setup stage
|
251
|
-
if stage ==
|
231
|
+
if stage == Stage.SETUP:
|
252
232
|
key_type_pairs = [
|
253
|
-
(
|
254
|
-
(
|
255
|
-
(
|
256
|
-
(
|
257
|
-
(
|
258
|
-
(
|
259
|
-
(
|
233
|
+
(Key.SAMPLE_NUMBER, int),
|
234
|
+
(Key.SECURE_ID, int),
|
235
|
+
(Key.SHARE_NUMBER, int),
|
236
|
+
(Key.THRESHOLD, int),
|
237
|
+
(Key.CLIPPING_RANGE, float),
|
238
|
+
(Key.TARGET_RANGE, int),
|
239
|
+
(Key.MOD_RANGE, int),
|
260
240
|
]
|
261
241
|
for key, expected_type in key_type_pairs:
|
262
242
|
if key not in configs:
|
263
243
|
raise KeyError(
|
264
|
-
f"Stage {
|
244
|
+
f"Stage {Stage.SETUP}: the required key '{key}' is "
|
265
245
|
"missing from the input `named_values`."
|
266
246
|
)
|
267
247
|
# Bool is a subclass of int in Python,
|
@@ -269,11 +249,11 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
|
|
269
249
|
# pylint: disable-next=unidiomatic-typecheck
|
270
250
|
if type(configs[key]) is not expected_type:
|
271
251
|
raise TypeError(
|
272
|
-
f"Stage {
|
252
|
+
f"Stage {Stage.SETUP}: The value for the key '{key}' "
|
273
253
|
f"must be of type {expected_type}, "
|
274
254
|
f"but got {type(configs[key])} instead."
|
275
255
|
)
|
276
|
-
elif stage ==
|
256
|
+
elif stage == Stage.SHARE_KEYS:
|
277
257
|
for key, value in configs.items():
|
278
258
|
if (
|
279
259
|
not isinstance(value, list)
|
@@ -282,18 +262,18 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
|
|
282
262
|
or not isinstance(value[1], bytes)
|
283
263
|
):
|
284
264
|
raise TypeError(
|
285
|
-
f"Stage {
|
265
|
+
f"Stage {Stage.SHARE_KEYS}: "
|
286
266
|
f"the value for the key '{key}' must be a list of two bytes."
|
287
267
|
)
|
288
|
-
elif stage ==
|
268
|
+
elif stage == Stage.COLLECT_MASKED_INPUT:
|
289
269
|
key_type_pairs = [
|
290
|
-
(
|
291
|
-
(
|
270
|
+
(Key.CIPHERTEXT_LIST, bytes),
|
271
|
+
(Key.SOURCE_LIST, int),
|
292
272
|
]
|
293
273
|
for key, expected_type in key_type_pairs:
|
294
274
|
if key not in configs:
|
295
275
|
raise KeyError(
|
296
|
-
f"Stage {
|
276
|
+
f"Stage {Stage.COLLECT_MASKED_INPUT}: "
|
297
277
|
f"the required key '{key}' is "
|
298
278
|
"missing from the input `named_values`."
|
299
279
|
)
|
@@ -304,19 +284,19 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
|
|
304
284
|
if type(elm) is not expected_type
|
305
285
|
):
|
306
286
|
raise TypeError(
|
307
|
-
f"Stage {
|
287
|
+
f"Stage {Stage.COLLECT_MASKED_INPUT}: "
|
308
288
|
f"the value for the key '{key}' "
|
309
289
|
f"must be of type List[{expected_type.__name__}]"
|
310
290
|
)
|
311
|
-
elif stage ==
|
291
|
+
elif stage == Stage.UNMASK:
|
312
292
|
key_type_pairs = [
|
313
|
-
(
|
314
|
-
(
|
293
|
+
(Key.ACTIVE_SECURE_ID_LIST, int),
|
294
|
+
(Key.DEAD_SECURE_ID_LIST, int),
|
315
295
|
]
|
316
296
|
for key, expected_type in key_type_pairs:
|
317
297
|
if key not in configs:
|
318
298
|
raise KeyError(
|
319
|
-
f"Stage {
|
299
|
+
f"Stage {Stage.UNMASK}: "
|
320
300
|
f"the required key '{key}' is "
|
321
301
|
"missing from the input `named_values`."
|
322
302
|
)
|
@@ -327,7 +307,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
|
|
327
307
|
if type(elm) is not expected_type
|
328
308
|
):
|
329
309
|
raise TypeError(
|
330
|
-
f"Stage {
|
310
|
+
f"Stage {Stage.UNMASK}: "
|
331
311
|
f"the value for the key '{key}' "
|
332
312
|
f"must be of type List[{expected_type.__name__}]"
|
333
313
|
)
|
@@ -340,15 +320,15 @@ def _setup(
|
|
340
320
|
) -> Dict[str, ConfigsRecordValues]:
|
341
321
|
# Assigning parameter values to object fields
|
342
322
|
sec_agg_param_dict = configs
|
343
|
-
state.sample_num = cast(int, sec_agg_param_dict[
|
344
|
-
state.sid = cast(int, sec_agg_param_dict[
|
323
|
+
state.sample_num = cast(int, sec_agg_param_dict[Key.SAMPLE_NUMBER])
|
324
|
+
state.sid = cast(int, sec_agg_param_dict[Key.SECURE_ID])
|
345
325
|
log(INFO, "Client %d: starting stage 0...", state.sid)
|
346
326
|
|
347
|
-
state.share_num = cast(int, sec_agg_param_dict[
|
348
|
-
state.threshold = cast(int, sec_agg_param_dict[
|
349
|
-
state.clipping_range = cast(float, sec_agg_param_dict[
|
350
|
-
state.target_range = cast(int, sec_agg_param_dict[
|
351
|
-
state.mod_range = cast(int, sec_agg_param_dict[
|
327
|
+
state.share_num = cast(int, sec_agg_param_dict[Key.SHARE_NUMBER])
|
328
|
+
state.threshold = cast(int, sec_agg_param_dict[Key.THRESHOLD])
|
329
|
+
state.clipping_range = cast(float, sec_agg_param_dict[Key.CLIPPING_RANGE])
|
330
|
+
state.target_range = cast(int, sec_agg_param_dict[Key.TARGET_RANGE])
|
331
|
+
state.mod_range = cast(int, sec_agg_param_dict[Key.MOD_RANGE])
|
352
332
|
|
353
333
|
# Dictionaries containing client secure IDs as keys
|
354
334
|
# and their respective secret shares as values.
|
@@ -367,7 +347,7 @@ def _setup(
|
|
367
347
|
state.sk1, state.pk1 = private_key_to_bytes(sk1), public_key_to_bytes(pk1)
|
368
348
|
state.sk2, state.pk2 = private_key_to_bytes(sk2), public_key_to_bytes(pk2)
|
369
349
|
log(INFO, "Client %d: stage 0 completes. uploading public keys...", state.sid)
|
370
|
-
return {
|
350
|
+
return {Key.PUBLIC_KEY_1: state.pk1, Key.PUBLIC_KEY_2: state.pk2}
|
371
351
|
|
372
352
|
|
373
353
|
# pylint: disable-next=too-many-locals
|
@@ -429,7 +409,7 @@ def _share_keys(
|
|
429
409
|
ciphertexts.append(ciphertext)
|
430
410
|
|
431
411
|
log(INFO, "Client %d: stage 1 completes. uploading key shares...", state.sid)
|
432
|
-
return {
|
412
|
+
return {Key.DESTINATION_LIST: dsts, Key.CIPHERTEXT_LIST: ciphertexts}
|
433
413
|
|
434
414
|
|
435
415
|
# pylint: disable-next=too-many-locals
|
@@ -440,8 +420,8 @@ def _collect_masked_input(
|
|
440
420
|
) -> Dict[str, ConfigsRecordValues]:
|
441
421
|
log(INFO, "Client %d: starting stage 2...", state.sid)
|
442
422
|
available_clients: List[int] = []
|
443
|
-
ciphertexts = cast(List[bytes], configs[
|
444
|
-
srcs = cast(List[int], configs[
|
423
|
+
ciphertexts = cast(List[bytes], configs[Key.CIPHERTEXT_LIST])
|
424
|
+
srcs = cast(List[int], configs[Key.SOURCE_LIST])
|
445
425
|
if len(ciphertexts) + 1 < state.threshold:
|
446
426
|
raise ValueError("Not enough available neighbour clients.")
|
447
427
|
|
@@ -505,7 +485,7 @@ def _collect_masked_input(
|
|
505
485
|
quantized_parameters = parameters_mod(quantized_parameters, state.mod_range)
|
506
486
|
log(INFO, "Client %d: stage 2 completes. uploading masked parameters...", state.sid)
|
507
487
|
return {
|
508
|
-
|
488
|
+
Key.MASKED_PARAMETERS: [ndarray_to_bytes(arr) for arr in quantized_parameters]
|
509
489
|
}
|
510
490
|
|
511
491
|
|
@@ -514,8 +494,8 @@ def _unmask(
|
|
514
494
|
) -> Dict[str, ConfigsRecordValues]:
|
515
495
|
log(INFO, "Client %d: starting stage 3...", state.sid)
|
516
496
|
|
517
|
-
active_sids = cast(List[int], configs[
|
518
|
-
dead_sids = cast(List[int], configs[
|
497
|
+
active_sids = cast(List[int], configs[Key.ACTIVE_SECURE_ID_LIST])
|
498
|
+
dead_sids = cast(List[int], configs[Key.DEAD_SECURE_ID_LIST])
|
519
499
|
# Send private mask seed share for every avaliable client (including itclient)
|
520
500
|
# Send first private key share for building pairwise mask for every dropped client
|
521
501
|
if len(active_sids) < state.threshold:
|
@@ -528,4 +508,4 @@ def _unmask(
|
|
528
508
|
shares += [state.sk1_share_dict[sid] for sid in dead_sids]
|
529
509
|
|
530
510
|
log(INFO, "Client %d: stage 3 completes. uploading key shares...", state.sid)
|
531
|
-
return {
|
511
|
+
return {Key.SECURE_ID_LIST: sids, Key.SHARE_LIST: shares}
|
@@ -14,33 +14,55 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
"""Constants for the SecAgg/SecAgg+ protocol."""
|
16
16
|
|
17
|
+
|
18
|
+
from __future__ import annotations
|
19
|
+
|
17
20
|
RECORD_KEY_STATE = "secaggplus_state"
|
18
21
|
RECORD_KEY_CONFIGS = "secaggplus_configs"
|
19
22
|
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
23
|
+
|
24
|
+
class Stage:
|
25
|
+
"""Stages for the SecAgg+ protocol."""
|
26
|
+
|
27
|
+
SETUP = "setup"
|
28
|
+
SHARE_KEYS = "share_keys"
|
29
|
+
COLLECT_MASKED_INPUT = "collect_masked_input"
|
30
|
+
UNMASK = "unmask"
|
31
|
+
_stages = (SETUP, SHARE_KEYS, COLLECT_MASKED_INPUT, UNMASK)
|
32
|
+
|
33
|
+
@classmethod
|
34
|
+
def all(cls) -> tuple[str, str, str, str]:
|
35
|
+
"""Return all stages."""
|
36
|
+
return cls._stages
|
37
|
+
|
38
|
+
def __new__(cls) -> Stage:
|
39
|
+
"""Prevent instantiation."""
|
40
|
+
raise TypeError(f"{cls.__name__} cannot be instantiated.")
|
41
|
+
|
42
|
+
|
43
|
+
class Key:
|
44
|
+
"""Keys for the configs in the ConfigsRecord."""
|
45
|
+
|
46
|
+
STAGE = "stage"
|
47
|
+
SAMPLE_NUMBER = "sample_num"
|
48
|
+
SECURE_ID = "secure_id"
|
49
|
+
SHARE_NUMBER = "share_num"
|
50
|
+
THRESHOLD = "threshold"
|
51
|
+
CLIPPING_RANGE = "clipping_range"
|
52
|
+
TARGET_RANGE = "target_range"
|
53
|
+
MOD_RANGE = "mod_range"
|
54
|
+
PUBLIC_KEY_1 = "pk1"
|
55
|
+
PUBLIC_KEY_2 = "pk2"
|
56
|
+
DESTINATION_LIST = "dsts"
|
57
|
+
CIPHERTEXT_LIST = "ctxts"
|
58
|
+
SOURCE_LIST = "srcs"
|
59
|
+
PARAMETERS = "params"
|
60
|
+
MASKED_PARAMETERS = "masked_params"
|
61
|
+
ACTIVE_SECURE_ID_LIST = "active_sids"
|
62
|
+
DEAD_SECURE_ID_LIST = "dead_sids"
|
63
|
+
SECURE_ID_LIST = "sids"
|
64
|
+
SHARE_LIST = "shares"
|
65
|
+
|
66
|
+
def __new__(cls) -> Key:
|
67
|
+
"""Prevent instantiation."""
|
68
|
+
raise TypeError(f"{cls.__name__} cannot be instantiated.")
|
flwr/proto/error_pb2.py
ADDED
@@ -0,0 +1,26 @@
|
|
1
|
+
# -*- coding: utf-8 -*-
|
2
|
+
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
3
|
+
# source: flwr/proto/error.proto
|
4
|
+
# Protobuf Python Version: 4.25.0
|
5
|
+
"""Generated protocol buffer code."""
|
6
|
+
from google.protobuf import descriptor as _descriptor
|
7
|
+
from google.protobuf import descriptor_pool as _descriptor_pool
|
8
|
+
from google.protobuf import symbol_database as _symbol_database
|
9
|
+
from google.protobuf.internal import builder as _builder
|
10
|
+
# @@protoc_insertion_point(imports)
|
11
|
+
|
12
|
+
_sym_db = _symbol_database.Default()
|
13
|
+
|
14
|
+
|
15
|
+
|
16
|
+
|
17
|
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/error.proto\x12\nflwr.proto\"%\n\x05\x45rror\x12\x0c\n\x04\x63ode\x18\x01 \x01(\x12\x12\x0e\n\x06reason\x18\x02 \x01(\tb\x06proto3')
|
18
|
+
|
19
|
+
_globals = globals()
|
20
|
+
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
21
|
+
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.error_pb2', _globals)
|
22
|
+
if _descriptor._USE_C_DESCRIPTORS == False:
|
23
|
+
DESCRIPTOR._options = None
|
24
|
+
_globals['_ERROR']._serialized_start=38
|
25
|
+
_globals['_ERROR']._serialized_end=75
|
26
|
+
# @@protoc_insertion_point(module_scope)
|
flwr/proto/error_pb2.pyi
ADDED
@@ -0,0 +1,25 @@
|
|
1
|
+
"""
|
2
|
+
@generated by mypy-protobuf. Do not edit manually!
|
3
|
+
isort:skip_file
|
4
|
+
"""
|
5
|
+
import builtins
|
6
|
+
import google.protobuf.descriptor
|
7
|
+
import google.protobuf.message
|
8
|
+
import typing
|
9
|
+
import typing_extensions
|
10
|
+
|
11
|
+
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
|
12
|
+
|
13
|
+
class Error(google.protobuf.message.Message):
|
14
|
+
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
15
|
+
CODE_FIELD_NUMBER: builtins.int
|
16
|
+
REASON_FIELD_NUMBER: builtins.int
|
17
|
+
code: builtins.int
|
18
|
+
reason: typing.Text
|
19
|
+
def __init__(self,
|
20
|
+
*,
|
21
|
+
code: builtins.int = ...,
|
22
|
+
reason: typing.Text = ...,
|
23
|
+
) -> None: ...
|
24
|
+
def ClearField(self, field_name: typing_extensions.Literal["code",b"code","reason",b"reason"]) -> None: ...
|
25
|
+
global___Error = Error
|
flwr/proto/task_pb2.py
CHANGED
@@ -15,19 +15,20 @@ _sym_db = _symbol_database.Default()
|
|
15
15
|
from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2
|
16
16
|
from flwr.proto import recordset_pb2 as flwr_dot_proto_dot_recordset__pb2
|
17
17
|
from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2
|
18
|
+
from flwr.proto import error_pb2 as flwr_dot_proto_dot_error__pb2
|
18
19
|
|
19
20
|
|
20
|
-
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\"\
|
21
|
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x16\x66lwr/proto/error.proto\"\xf6\x01\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\t\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x0b\n\x03ttl\x18\x05 \x01(\t\x12\x10\n\x08\x61ncestry\x18\x06 \x03(\t\x12\x11\n\ttask_type\x18\x07 \x01(\t\x12(\n\trecordset\x18\x08 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12 \n\x05\x65rror\x18\t \x01(\x0b\x32\x11.flwr.proto.Error\"\\\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\\\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Taskb\x06proto3')
|
21
22
|
|
22
23
|
_globals = globals()
|
23
24
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
24
25
|
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.task_pb2', _globals)
|
25
26
|
if _descriptor._USE_C_DESCRIPTORS == False:
|
26
27
|
DESCRIPTOR._options = None
|
27
|
-
_globals['_TASK']._serialized_start=
|
28
|
-
_globals['_TASK']._serialized_end=
|
29
|
-
_globals['_TASKINS']._serialized_start=
|
30
|
-
_globals['_TASKINS']._serialized_end=
|
31
|
-
_globals['_TASKRES']._serialized_start=
|
32
|
-
_globals['_TASKRES']._serialized_end=
|
28
|
+
_globals['_TASK']._serialized_start=141
|
29
|
+
_globals['_TASK']._serialized_end=387
|
30
|
+
_globals['_TASKINS']._serialized_start=389
|
31
|
+
_globals['_TASKINS']._serialized_end=481
|
32
|
+
_globals['_TASKRES']._serialized_start=483
|
33
|
+
_globals['_TASKRES']._serialized_end=575
|
33
34
|
# @@protoc_insertion_point(module_scope)
|
flwr/proto/task_pb2.pyi
CHANGED
@@ -3,6 +3,7 @@
|
|
3
3
|
isort:skip_file
|
4
4
|
"""
|
5
5
|
import builtins
|
6
|
+
import flwr.proto.error_pb2
|
6
7
|
import flwr.proto.node_pb2
|
7
8
|
import flwr.proto.recordset_pb2
|
8
9
|
import google.protobuf.descriptor
|
@@ -23,6 +24,7 @@ class Task(google.protobuf.message.Message):
|
|
23
24
|
ANCESTRY_FIELD_NUMBER: builtins.int
|
24
25
|
TASK_TYPE_FIELD_NUMBER: builtins.int
|
25
26
|
RECORDSET_FIELD_NUMBER: builtins.int
|
27
|
+
ERROR_FIELD_NUMBER: builtins.int
|
26
28
|
@property
|
27
29
|
def producer(self) -> flwr.proto.node_pb2.Node: ...
|
28
30
|
@property
|
@@ -35,6 +37,8 @@ class Task(google.protobuf.message.Message):
|
|
35
37
|
task_type: typing.Text
|
36
38
|
@property
|
37
39
|
def recordset(self) -> flwr.proto.recordset_pb2.RecordSet: ...
|
40
|
+
@property
|
41
|
+
def error(self) -> flwr.proto.error_pb2.Error: ...
|
38
42
|
def __init__(self,
|
39
43
|
*,
|
40
44
|
producer: typing.Optional[flwr.proto.node_pb2.Node] = ...,
|
@@ -45,9 +49,10 @@ class Task(google.protobuf.message.Message):
|
|
45
49
|
ancestry: typing.Optional[typing.Iterable[typing.Text]] = ...,
|
46
50
|
task_type: typing.Text = ...,
|
47
51
|
recordset: typing.Optional[flwr.proto.recordset_pb2.RecordSet] = ...,
|
52
|
+
error: typing.Optional[flwr.proto.error_pb2.Error] = ...,
|
48
53
|
) -> None: ...
|
49
|
-
def HasField(self, field_name: typing_extensions.Literal["consumer",b"consumer","producer",b"producer","recordset",b"recordset"]) -> builtins.bool: ...
|
50
|
-
def ClearField(self, field_name: typing_extensions.Literal["ancestry",b"ancestry","consumer",b"consumer","created_at",b"created_at","delivered_at",b"delivered_at","producer",b"producer","recordset",b"recordset","task_type",b"task_type","ttl",b"ttl"]) -> None: ...
|
54
|
+
def HasField(self, field_name: typing_extensions.Literal["consumer",b"consumer","error",b"error","producer",b"producer","recordset",b"recordset"]) -> builtins.bool: ...
|
55
|
+
def ClearField(self, field_name: typing_extensions.Literal["ancestry",b"ancestry","consumer",b"consumer","created_at",b"created_at","delivered_at",b"delivered_at","error",b"error","producer",b"producer","recordset",b"recordset","task_type",b"task_type","ttl",b"ttl"]) -> None: ...
|
51
56
|
global___Task = Task
|
52
57
|
|
53
58
|
class TaskIns(google.protobuf.message.Message):
|
flwr/server/strategy/__init__.py
CHANGED
@@ -16,10 +16,17 @@
|
|
16
16
|
|
17
17
|
|
18
18
|
from .bulyan import Bulyan as Bulyan
|
19
|
-
from .dp_adaptive_clipping import
|
19
|
+
from .dp_adaptive_clipping import (
|
20
|
+
DifferentialPrivacyClientSideAdaptiveClipping as DifferentialPrivacyClientSideAdaptiveClipping,
|
21
|
+
)
|
22
|
+
from .dp_adaptive_clipping import (
|
23
|
+
DifferentialPrivacyServerSideAdaptiveClipping as DifferentialPrivacyServerSideAdaptiveClipping,
|
24
|
+
)
|
25
|
+
from .dp_fixed_clipping import (
|
26
|
+
DifferentialPrivacyClientSideFixedClipping as DifferentialPrivacyClientSideFixedClipping,
|
27
|
+
)
|
20
28
|
from .dp_fixed_clipping import (
|
21
|
-
|
22
|
-
DifferentialPrivacyServerSideFixedClipping,
|
29
|
+
DifferentialPrivacyServerSideFixedClipping as DifferentialPrivacyServerSideFixedClipping,
|
23
30
|
)
|
24
31
|
from .dpfedavg_adaptive import DPFedAvgAdaptive as DPFedAvgAdaptive
|
25
32
|
from .dpfedavg_fixed import DPFedAvgFixed as DPFedAvgFixed
|
@@ -46,6 +53,7 @@ __all__ = [
|
|
46
53
|
"DPFedAvgAdaptive",
|
47
54
|
"DPFedAvgFixed",
|
48
55
|
"DifferentialPrivacyClientSideAdaptiveClipping",
|
56
|
+
"DifferentialPrivacyServerSideAdaptiveClipping",
|
49
57
|
"DifferentialPrivacyClientSideFixedClipping",
|
50
58
|
"DifferentialPrivacyServerSideFixedClipping",
|
51
59
|
"FedAdagrad",
|
@@ -24,8 +24,19 @@ from typing import Dict, List, Optional, Tuple, Union
|
|
24
24
|
|
25
25
|
import numpy as np
|
26
26
|
|
27
|
-
from flwr.common import
|
27
|
+
from flwr.common import (
|
28
|
+
EvaluateIns,
|
29
|
+
EvaluateRes,
|
30
|
+
FitIns,
|
31
|
+
FitRes,
|
32
|
+
NDArrays,
|
33
|
+
Parameters,
|
34
|
+
Scalar,
|
35
|
+
ndarrays_to_parameters,
|
36
|
+
parameters_to_ndarrays,
|
37
|
+
)
|
28
38
|
from flwr.common.differential_privacy import (
|
39
|
+
adaptive_clip_inputs_inplace,
|
29
40
|
add_gaussian_noise_to_params,
|
30
41
|
compute_adaptive_noise_params,
|
31
42
|
)
|
@@ -40,6 +51,199 @@ from flwr.server.client_proxy import ClientProxy
|
|
40
51
|
from flwr.server.strategy.strategy import Strategy
|
41
52
|
|
42
53
|
|
54
|
+
class DifferentialPrivacyServerSideAdaptiveClipping(Strategy):
|
55
|
+
"""Strategy wrapper for central DP with server-side adaptive clipping.
|
56
|
+
|
57
|
+
Parameters
|
58
|
+
----------
|
59
|
+
strategy: Strategy
|
60
|
+
The strategy to which DP functionalities will be added by this wrapper.
|
61
|
+
noise_multiplier : float
|
62
|
+
The noise multiplier for the Gaussian mechanism for model updates.
|
63
|
+
num_sampled_clients : int
|
64
|
+
The number of clients that are sampled on each round.
|
65
|
+
initial_clipping_norm : float
|
66
|
+
The initial value of clipping norm. Deafults to 0.1.
|
67
|
+
Andrew et al. recommends to set to 0.1.
|
68
|
+
target_clipped_quantile : float
|
69
|
+
The desired quantile of updates which should be clipped. Defaults to 0.5.
|
70
|
+
clip_norm_lr : float
|
71
|
+
The learning rate for the clipping norm adaptation. Defaults to 0.2.
|
72
|
+
Andrew et al. recommends to set to 0.2.
|
73
|
+
clipped_count_stddev : float
|
74
|
+
The standard deviation of the noise added to the count of updates below the estimate.
|
75
|
+
Andrew et al. recommends to set to `expected_num_records/20`
|
76
|
+
|
77
|
+
Examples
|
78
|
+
--------
|
79
|
+
Create a strategy:
|
80
|
+
|
81
|
+
>>> strategy = fl.server.strategy.FedAvg( ... )
|
82
|
+
|
83
|
+
Wrap the strategy with the DifferentialPrivacyServerSideAdaptiveClipping wrapper
|
84
|
+
|
85
|
+
>>> dp_strategy = DifferentialPrivacyServerSideAdaptiveClipping(
|
86
|
+
>>> strategy, cfg.noise_multiplier, cfg.num_sampled_clients, ...
|
87
|
+
>>> )
|
88
|
+
"""
|
89
|
+
|
90
|
+
# pylint: disable=too-many-arguments,too-many-instance-attributes
|
91
|
+
def __init__(
|
92
|
+
self,
|
93
|
+
strategy: Strategy,
|
94
|
+
noise_multiplier: float,
|
95
|
+
num_sampled_clients: int,
|
96
|
+
initial_clipping_norm: float = 0.1,
|
97
|
+
target_clipped_quantile: float = 0.5,
|
98
|
+
clip_norm_lr: float = 0.2,
|
99
|
+
clipped_count_stddev: Optional[float] = None,
|
100
|
+
) -> None:
|
101
|
+
super().__init__()
|
102
|
+
|
103
|
+
if strategy is None:
|
104
|
+
raise ValueError("The passed strategy is None.")
|
105
|
+
|
106
|
+
if noise_multiplier < 0:
|
107
|
+
raise ValueError("The noise multiplier should be a non-negative value.")
|
108
|
+
|
109
|
+
if num_sampled_clients <= 0:
|
110
|
+
raise ValueError(
|
111
|
+
"The number of sampled clients should be a positive value."
|
112
|
+
)
|
113
|
+
|
114
|
+
if initial_clipping_norm <= 0:
|
115
|
+
raise ValueError("The initial clipping norm should be a positive value.")
|
116
|
+
|
117
|
+
if not 0 <= target_clipped_quantile <= 1:
|
118
|
+
raise ValueError(
|
119
|
+
"The target clipped quantile must be between 0 and 1 (inclusive)."
|
120
|
+
)
|
121
|
+
|
122
|
+
if clip_norm_lr <= 0:
|
123
|
+
raise ValueError("The learning rate must be positive.")
|
124
|
+
|
125
|
+
if clipped_count_stddev is not None:
|
126
|
+
if clipped_count_stddev < 0:
|
127
|
+
raise ValueError("The `clipped_count_stddev` must be non-negative.")
|
128
|
+
|
129
|
+
self.strategy = strategy
|
130
|
+
self.num_sampled_clients = num_sampled_clients
|
131
|
+
self.clipping_norm = initial_clipping_norm
|
132
|
+
self.target_clipped_quantile = target_clipped_quantile
|
133
|
+
self.clip_norm_lr = clip_norm_lr
|
134
|
+
(
|
135
|
+
self.clipped_count_stddev,
|
136
|
+
self.noise_multiplier,
|
137
|
+
) = compute_adaptive_noise_params(
|
138
|
+
noise_multiplier,
|
139
|
+
num_sampled_clients,
|
140
|
+
clipped_count_stddev,
|
141
|
+
)
|
142
|
+
|
143
|
+
self.current_round_params: NDArrays = []
|
144
|
+
|
145
|
+
def __repr__(self) -> str:
|
146
|
+
"""Compute a string representation of the strategy."""
|
147
|
+
rep = "Differential Privacy Strategy Wrapper (Server-Side Adaptive Clipping)"
|
148
|
+
return rep
|
149
|
+
|
150
|
+
def initialize_parameters(
|
151
|
+
self, client_manager: ClientManager
|
152
|
+
) -> Optional[Parameters]:
|
153
|
+
"""Initialize global model parameters using given strategy."""
|
154
|
+
return self.strategy.initialize_parameters(client_manager)
|
155
|
+
|
156
|
+
def configure_fit(
|
157
|
+
self, server_round: int, parameters: Parameters, client_manager: ClientManager
|
158
|
+
) -> List[Tuple[ClientProxy, FitIns]]:
|
159
|
+
"""Configure the next round of training."""
|
160
|
+
self.current_round_params = parameters_to_ndarrays(parameters)
|
161
|
+
return self.strategy.configure_fit(server_round, parameters, client_manager)
|
162
|
+
|
163
|
+
def configure_evaluate(
|
164
|
+
self, server_round: int, parameters: Parameters, client_manager: ClientManager
|
165
|
+
) -> List[Tuple[ClientProxy, EvaluateIns]]:
|
166
|
+
"""Configure the next round of evaluation."""
|
167
|
+
return self.strategy.configure_evaluate(
|
168
|
+
server_round, parameters, client_manager
|
169
|
+
)
|
170
|
+
|
171
|
+
def aggregate_fit(
|
172
|
+
self,
|
173
|
+
server_round: int,
|
174
|
+
results: List[Tuple[ClientProxy, FitRes]],
|
175
|
+
failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
|
176
|
+
) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
|
177
|
+
"""Aggregate training results and update clip norms."""
|
178
|
+
if failures:
|
179
|
+
return None, {}
|
180
|
+
|
181
|
+
if len(results) != self.num_sampled_clients:
|
182
|
+
log(
|
183
|
+
WARNING,
|
184
|
+
CLIENTS_DISCREPANCY_WARNING,
|
185
|
+
len(results),
|
186
|
+
self.num_sampled_clients,
|
187
|
+
)
|
188
|
+
|
189
|
+
norm_bit_set_count = 0
|
190
|
+
for _, res in results:
|
191
|
+
param = parameters_to_ndarrays(res.parameters)
|
192
|
+
# Compute and clip update
|
193
|
+
model_update = [
|
194
|
+
np.subtract(x, y) for (x, y) in zip(param, self.current_round_params)
|
195
|
+
]
|
196
|
+
|
197
|
+
norm_bit = adaptive_clip_inputs_inplace(model_update, self.clipping_norm)
|
198
|
+
norm_bit_set_count += norm_bit
|
199
|
+
|
200
|
+
for i, _ in enumerate(self.current_round_params):
|
201
|
+
param[i] = self.current_round_params[i] + model_update[i]
|
202
|
+
# Convert back to parameters
|
203
|
+
res.parameters = ndarrays_to_parameters(param)
|
204
|
+
|
205
|
+
# Noising the count
|
206
|
+
noised_norm_bit_set_count = float(
|
207
|
+
np.random.normal(norm_bit_set_count, self.clipped_count_stddev)
|
208
|
+
)
|
209
|
+
noised_norm_bit_set_fraction = noised_norm_bit_set_count / len(results)
|
210
|
+
# Geometric update
|
211
|
+
self.clipping_norm *= math.exp(
|
212
|
+
-self.clip_norm_lr
|
213
|
+
* (noised_norm_bit_set_fraction - self.target_clipped_quantile)
|
214
|
+
)
|
215
|
+
|
216
|
+
aggregated_params, metrics = self.strategy.aggregate_fit(
|
217
|
+
server_round, results, failures
|
218
|
+
)
|
219
|
+
|
220
|
+
# Add Gaussian noise to the aggregated parameters
|
221
|
+
if aggregated_params:
|
222
|
+
aggregated_params = add_gaussian_noise_to_params(
|
223
|
+
aggregated_params,
|
224
|
+
self.noise_multiplier,
|
225
|
+
self.clipping_norm,
|
226
|
+
self.num_sampled_clients,
|
227
|
+
)
|
228
|
+
|
229
|
+
return aggregated_params, metrics
|
230
|
+
|
231
|
+
def aggregate_evaluate(
|
232
|
+
self,
|
233
|
+
server_round: int,
|
234
|
+
results: List[Tuple[ClientProxy, EvaluateRes]],
|
235
|
+
failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
|
236
|
+
) -> Tuple[Optional[float], Dict[str, Scalar]]:
|
237
|
+
"""Aggregate evaluation losses using the given strategy."""
|
238
|
+
return self.strategy.aggregate_evaluate(server_round, results, failures)
|
239
|
+
|
240
|
+
def evaluate(
    self, server_round: int, parameters: Parameters
) -> Optional[Tuple[float, Dict[str, Scalar]]]:
    """Evaluate model parameters through the wrapped strategy.

    Parameters
    ----------
    server_round : int
        The current server round.
    parameters : Parameters
        The model parameters to evaluate.

    Returns
    -------
    Optional[Tuple[float, Dict[str, Scalar]]]
        Exactly what ``self.strategy.evaluate`` returns.
    """
    delegate = self.strategy.evaluate
    return delegate(server_round, parameters)
|
245
|
+
|
246
|
+
|
43
247
|
class DifferentialPrivacyClientSideAdaptiveClipping(Strategy):
|
44
248
|
"""Strategy wrapper for central DP with client-side adaptive clipping.
|
45
249
|
|
@@ -15,7 +15,7 @@
|
|
15
15
|
"""Driver API servicer."""
|
16
16
|
|
17
17
|
|
18
|
-
from logging import INFO
|
18
|
+
from logging import DEBUG, INFO
|
19
19
|
from typing import List, Optional, Set
|
20
20
|
from uuid import UUID
|
21
21
|
|
@@ -70,7 +70,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
70
70
|
self, request: PushTaskInsRequest, context: grpc.ServicerContext
|
71
71
|
) -> PushTaskInsResponse:
|
72
72
|
"""Push a set of TaskIns."""
|
73
|
-
log(
|
73
|
+
log(DEBUG, "DriverServicer.PushTaskIns")
|
74
74
|
|
75
75
|
# Validate request
|
76
76
|
_raise_if(len(request.task_ins_list) == 0, "`task_ins_list` must not be empty")
|
@@ -95,7 +95,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
95
95
|
self, request: PullTaskResRequest, context: grpc.ServicerContext
|
96
96
|
) -> PullTaskResResponse:
|
97
97
|
"""Pull a set of TaskRes."""
|
98
|
-
log(
|
98
|
+
log(DEBUG, "DriverServicer.PullTaskRes")
|
99
99
|
|
100
100
|
# Convert each task_id str to UUID
|
101
101
|
task_ids: Set[UUID] = {UUID(task_id) for task_id in request.task_ids}
|
@@ -105,7 +105,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
105
105
|
|
106
106
|
# Register callback
|
107
107
|
def on_rpc_done() -> None:
|
108
|
-
log(
|
108
|
+
log(DEBUG, "DriverServicer.PullTaskRes callback: delete TaskIns/TaskRes")
|
109
109
|
|
110
110
|
if context.is_active():
|
111
111
|
return
|
flwr/simulation/__init__.py
CHANGED
@@ -17,6 +17,8 @@
|
|
17
17
|
|
18
18
|
import importlib
|
19
19
|
|
20
|
+
from flwr.simulation.run_simulation import run_simulation
|
21
|
+
|
20
22
|
is_ray_installed = importlib.util.find_spec("ray") is not None
|
21
23
|
|
22
24
|
if is_ray_installed:
|
@@ -36,4 +38,5 @@ To install the necessary dependencies, install `flwr` with the `simulation` extr
|
|
36
38
|
|
37
39
|
__all__ = [
|
38
40
|
"start_simulation",
|
41
|
+
"run_simulation",
|
39
42
|
]
|
@@ -0,0 +1,177 @@
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Flower Simulation."""
|
16
|
+
|
17
|
+
import argparse
|
18
|
+
import asyncio
|
19
|
+
import json
|
20
|
+
import threading
|
21
|
+
import traceback
|
22
|
+
from logging import ERROR, INFO, WARNING
|
23
|
+
|
24
|
+
import grpc
|
25
|
+
|
26
|
+
from flwr.common import EventType, event, log
|
27
|
+
from flwr.common.exit_handlers import register_exit_handlers
|
28
|
+
from flwr.server.driver.driver import Driver
|
29
|
+
from flwr.server.run_serverapp import run
|
30
|
+
from flwr.server.superlink.driver.driver_grpc import run_driver_api_grpc
|
31
|
+
from flwr.server.superlink.fleet import vce
|
32
|
+
from flwr.server.superlink.state import StateFactory
|
33
|
+
from flwr.simulation.ray_transport.utils import enable_tf_gpu_growth
|
34
|
+
|
35
|
+
|
36
|
+
def run_simulation() -> None:
    """Run Simulation Engine.

    Entry point of the ``flower-simulation`` command. It parses the command
    line and starts a Driver API gRPC server backed by an in-memory state.
    It then launches the Simulation Engine (``vce.start_vce``) in a background
    thread for the ClientApp. Finally, it runs the ServerApp on the calling
    thread through a ``Driver`` connected to that Driver API.

    Once the ServerApp returns or fails, the stop event is set, the Simulation
    Engine thread is joined, and the Driver API server is shut down.

    Raises
    ------
    RuntimeError
        If creating the ``Driver`` or running the ServerApp raises. The
        original exception is chained as ``__cause__``.
    """
    args = _parse_args_run_simulation().parse_args()

    # Load JSON config
    backend_config_dict = json.loads(args.backend_config)

    # Enable GPU memory growth (relevant only for TF)
    if args.enable_tf_gpu_growth:
        log(INFO, "Enabling GPU growth for Tensorflow on the main thread.")
        enable_tf_gpu_growth()
        # Check that Backend config has also enabled using GPU growth
        use_tf = backend_config_dict.get("tensorflow", False)
        if not use_tf:
            log(WARNING, "Enabling GPU growth for your backend.")
            backend_config_dict["tensorflow"] = True

    # Convert back to JSON stream
    backend_config = json.dumps(backend_config_dict)

    # Initialize StateFactory (the same instance is handed to both the Driver
    # API and the Simulation Engine below)
    state_factory = StateFactory(":flwr-in-memory-state:")

    # Start Driver API
    driver_server: grpc.Server = run_driver_api_grpc(
        address=args.driver_api_address,
        state_factory=state_factory,
        certificates=None,
    )

    # SuperLink with Simulation Engine, running in a background (non-daemon)
    # thread; `f_stop` is the event used to tell it to stop.
    f_stop = asyncio.Event()
    superlink_th = threading.Thread(
        target=vce.start_vce,
        kwargs={
            "num_supernodes": args.num_supernodes,
            "client_app_module_name": args.client_app,
            "backend_name": args.backend,
            "backend_config_json_stream": backend_config,
            "working_dir": args.dir,
            "state_factory": state_factory,
            "f_stop": f_stop,
        },
        daemon=False,
    )

    superlink_th.start()
    event(EventType.RUN_SUPERLINK_ENTER)

    # Bind `driver` before entering `try` so the `finally` clause can always
    # `del` it. If `Driver(...)` raised while `driver` was unbound, `del driver`
    # would raise UnboundLocalError. That would mask the original error and
    # skip `f_stop.set()`, leaving the non-daemon engine thread running forever.
    driver = None
    try:
        # Initialize Driver
        driver = Driver(
            driver_service_address=args.driver_api_address,
            root_certificates=None,
        )

        # Launch server app
        run(args.server_app, driver, args.dir)

    except Exception as ex:
        log(ERROR, "An exception occurred: %s", ex)
        log(ERROR, traceback.format_exc())
        raise RuntimeError(
            "An error was encountered by the Simulation Engine. Ending simulation."
        ) from ex

    finally:
        # Drop our reference to the Driver (it may be None if creation failed)
        del driver

        # Trigger stop event
        f_stop.set()

        register_exit_handlers(
            grpc_servers=[driver_server],
            bckg_threads=[superlink_th],
            event_type=EventType.RUN_SUPERLINK_LEAVE,
        )
        superlink_th.join()

        # The ServerApp is done and the engine thread has exited, so nothing
        # uses the Driver API any more: shut its gRPC server down instead of
        # leaking it.
        driver_server.stop(grace=0)
|
116
|
+
|
117
|
+
|
118
|
+
def _parse_args_run_simulation() -> argparse.ArgumentParser:
|
119
|
+
"""Parse flower-simulation command line arguments."""
|
120
|
+
parser = argparse.ArgumentParser(
|
121
|
+
description="Start a Flower simulation",
|
122
|
+
)
|
123
|
+
parser.add_argument(
|
124
|
+
"--client-app",
|
125
|
+
required=True,
|
126
|
+
help="For example: `client:app` or `project.package.module:wrapper.app`",
|
127
|
+
)
|
128
|
+
parser.add_argument(
|
129
|
+
"--server-app",
|
130
|
+
required=True,
|
131
|
+
help="For example: `server:app` or `project.package.module:wrapper.app`",
|
132
|
+
)
|
133
|
+
parser.add_argument(
|
134
|
+
"--driver-api-address",
|
135
|
+
default="0.0.0.0:9091",
|
136
|
+
type=str,
|
137
|
+
help="For example: `server:app` or `project.package.module:wrapper.app`",
|
138
|
+
)
|
139
|
+
parser.add_argument(
|
140
|
+
"--num-supernodes",
|
141
|
+
type=int,
|
142
|
+
required=True,
|
143
|
+
help="Number of simulated SuperNodes.",
|
144
|
+
)
|
145
|
+
parser.add_argument(
|
146
|
+
"--backend",
|
147
|
+
default="ray",
|
148
|
+
type=str,
|
149
|
+
help="Simulation backend that executes the ClientApp.",
|
150
|
+
)
|
151
|
+
parser.add_argument(
|
152
|
+
"--enable-tf-gpu-growth",
|
153
|
+
action="store_true",
|
154
|
+
help="Enables GPU growth on the main thread. This is desirable if you make "
|
155
|
+
"use of a TensorFlow model on your `ServerApp` while having your `ClientApp` "
|
156
|
+
"running on the same GPU. Without enabling this, you might encounter an "
|
157
|
+
"out-of-memory error becasue TensorFlow by default allocates all GPU memory."
|
158
|
+
"Read mor about how `tf.config.experimental.set_memory_growth()` works in "
|
159
|
+
"the TensorFlow documentation: https://www.tensorflow.org/api/stable.",
|
160
|
+
)
|
161
|
+
parser.add_argument(
|
162
|
+
"--backend-config",
|
163
|
+
type=str,
|
164
|
+
default='{"client_resources": {"num_cpus":2, "num_gpus":0.0}, "tensorflow": 0}',
|
165
|
+
help='A JSON formatted stream, e.g \'{"<keyA>":<value>, "<keyB>":<value>}\' to '
|
166
|
+
"configure a backend. Values supported in <value> are those included by "
|
167
|
+
"`flwr.common.typing.ConfigsRecordValues`. ",
|
168
|
+
)
|
169
|
+
parser.add_argument(
|
170
|
+
"--dir",
|
171
|
+
default="",
|
172
|
+
help="Add specified directory to the PYTHONPATH and load"
|
173
|
+
"ClientApp and ServerApp from there."
|
174
|
+
" Default: current working directory.",
|
175
|
+
)
|
176
|
+
|
177
|
+
return parser
|
{flwr_nightly-1.8.0.dev20240228.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/RECORD
RENAMED
@@ -32,7 +32,7 @@ flwr/client/message_handler/task_handler.py,sha256=ZDJBKmrn2grRMNl1rU1iGs7FiMHL5
|
|
32
32
|
flwr/client/mod/__init__.py,sha256=w6r7n6fWIrrm4lEk36lh9f1Ix6LXgAzQUrgjmMspY98,961
|
33
33
|
flwr/client/mod/centraldp_mods.py,sha256=aHbzGjSbyRENuU5vzad_tkJ9UDb48uHEvUq-zgydBwo,4954
|
34
34
|
flwr/client/mod/secure_aggregation/__init__.py,sha256=AzCdezuzX2BfXUuxVRwXdv8-zUIXoU-Bf6u4LRhzvg8,796
|
35
|
-
flwr/client/mod/secure_aggregation/secaggplus_mod.py,sha256=
|
35
|
+
flwr/client/mod/secure_aggregation/secaggplus_mod.py,sha256=Zc5b2C58SYeyQVXIbLRESOIq4rMUOuzMHNAjdSAPt6I,19434
|
36
36
|
flwr/client/mod/utils.py,sha256=lvETHcCYsSWz7h8I772hCV_kZspxqlMqzriMZ-SxmKc,1226
|
37
37
|
flwr/client/node_state.py,sha256=KTTs_l4I0jBM7IsSsbAGjhfL_yZC3QANbzyvyfZBRDM,1778
|
38
38
|
flwr/client/node_state_tests.py,sha256=gPwz0zf2iuDSa11jedkur_u3Xm7lokIDG5ALD2MCvSw,2195
|
@@ -68,7 +68,7 @@ flwr/common/secure_aggregation/crypto/shamir.py,sha256=yY35ZgHlB4YyGW_buG-1X-0M-
|
|
68
68
|
flwr/common/secure_aggregation/crypto/symmetric_encryption.py,sha256=-zDyQoTsHHQjR7o-92FNIikg1zM_Ke9yynaD5u2BXbQ,3546
|
69
69
|
flwr/common/secure_aggregation/ndarrays_arithmetic.py,sha256=KAHCEHGSTJ6mCgnC8dTpIx6URk11s5XxWTHI_7ToGIg,2979
|
70
70
|
flwr/common/secure_aggregation/quantization.py,sha256=appui7GGrkRPsupF59TkapeV4Na_CyPi73JtJ1pimdI,2310
|
71
|
-
flwr/common/secure_aggregation/secaggplus_constants.py,sha256=
|
71
|
+
flwr/common/secure_aggregation/secaggplus_constants.py,sha256=2dYMKqiO2ja8DewFeZUAGqM0xujNfyYHVwShRxeRaIM,2132
|
72
72
|
flwr/common/secure_aggregation/secaggplus_utils.py,sha256=PleDyDu7jHNAfbRoEaoQiOjxG6iMl9yA8rNKYTfnyFw,3155
|
73
73
|
flwr/common/serde.py,sha256=0tmfTcywJVLA7Hsu4nAjMn2dVMVbjYZqePJbPcDt01Y,20726
|
74
74
|
flwr/common/telemetry.py,sha256=JkFB6WBOskqAJfzSM-l6tQfRiSi2oiysClfg0-5T7NY,7782
|
@@ -79,6 +79,10 @@ flwr/proto/driver_pb2.py,sha256=JHIdjNPTgp6YHD-_lz5ZZFB0VIOR3_GmcaOTN4jndc4,3115
|
|
79
79
|
flwr/proto/driver_pb2.pyi,sha256=xwl2AqIWn0SwAlg-x5RUQeqr6DC48eywnqmD7gbaaFs,4670
|
80
80
|
flwr/proto/driver_pb2_grpc.py,sha256=qQBRdQUz4k2K4DVO7kSfWHx-62UJ85HaYKnKCr6JcU8,7304
|
81
81
|
flwr/proto/driver_pb2_grpc.pyi,sha256=NpOM5eCrIPcuWdYrZAayQSDvvFp6cDCVflabhmuvMfo,2022
|
82
|
+
flwr/proto/error_pb2.py,sha256=LarjKL90LbwkXKlhzNrDssgl4DXcvIPve8NVCXHpsKA,1084
|
83
|
+
flwr/proto/error_pb2.pyi,sha256=ZNH4HhJTU_KfMXlyCeg8FwU-fcUYxTqEmoJPtWtHikc,734
|
84
|
+
flwr/proto/error_pb2_grpc.py,sha256=1oboBPFxaTEXt9Aw7EAj8gXHDCNMhZD2VXqocC9l_gk,159
|
85
|
+
flwr/proto/error_pb2_grpc.pyi,sha256=ff2TSiLVnG6IVQcTGzb2DIH3XRSoAvAo_RMcvbMFyc0,76
|
82
86
|
flwr/proto/fleet_pb2.py,sha256=8rKQHu6Oa9ki_NG6kRNGtfPPYZp5kKBZhPW696_kn84,3852
|
83
87
|
flwr/proto/fleet_pb2.pyi,sha256=QXYs9M7_dABghdCMfk5Rjf4w0LsZGDeQ1ojH00XaQME,6182
|
84
88
|
flwr/proto/fleet_pb2_grpc.py,sha256=hF1uPaioZzQMRCP9yPlv9LC0mi_DTuhn-IkQJzWIPCs,7505
|
@@ -91,8 +95,8 @@ flwr/proto/recordset_pb2.py,sha256=un8L0kvBcgFXQIiQweOseeIJBjlOozUvQY9uTQ42Dqo,6
|
|
91
95
|
flwr/proto/recordset_pb2.pyi,sha256=NPzCJWAj1xLWzeZ_xZ6uaObQjQfWGnnqlLtn4J-SoFY,14161
|
92
96
|
flwr/proto/recordset_pb2_grpc.py,sha256=1oboBPFxaTEXt9Aw7EAj8gXHDCNMhZD2VXqocC9l_gk,159
|
93
97
|
flwr/proto/recordset_pb2_grpc.pyi,sha256=ff2TSiLVnG6IVQcTGzb2DIH3XRSoAvAo_RMcvbMFyc0,76
|
94
|
-
flwr/proto/task_pb2.py,sha256
|
95
|
-
flwr/proto/task_pb2.pyi,sha256=
|
98
|
+
flwr/proto/task_pb2.py,sha256=-UX3TqskOIRbPu8U3YwgW9ul2k9ZN6MJGgbIOX3pTqg,2431
|
99
|
+
flwr/proto/task_pb2.pyi,sha256=IgXggFya0RpL64z2o2K_qLnZHyZ1mg_WzLxFwEKrpL0,4171
|
96
100
|
flwr/proto/task_pb2_grpc.py,sha256=1oboBPFxaTEXt9Aw7EAj8gXHDCNMhZD2VXqocC9l_gk,159
|
97
101
|
flwr/proto/task_pb2_grpc.pyi,sha256=ff2TSiLVnG6IVQcTGzb2DIH3XRSoAvAo_RMcvbMFyc0,76
|
98
102
|
flwr/proto/transport_pb2.py,sha256=cURzfpCgZvH7GEvBPLvTYijE3HvhK1MePjINk4xYArk,9781
|
@@ -118,10 +122,10 @@ flwr/server/run_serverapp.py,sha256=7LLE1cVQz0Rl-hZnY7DLXvFxWCdep8xgLgEVC-yffi0,
|
|
118
122
|
flwr/server/server.py,sha256=JLc2lg-qCchD-Jyg_hTBZQN3rXsnfAGx6qJAo0vqH2Y,17812
|
119
123
|
flwr/server/server_app.py,sha256=avNQ7AMMKsn09ly81C3UBgOfHhM_R29l4MrzlalGoj8,5892
|
120
124
|
flwr/server/server_config.py,sha256=yOHpkdyuhOm--Gy_4Vofvu6jCDxhyECEDpIy02beuCg,1018
|
121
|
-
flwr/server/strategy/__init__.py,sha256=
|
125
|
+
flwr/server/strategy/__init__.py,sha256=7eVZ3hQEg2BgA_usAeL6tsLp9T6XI1VYYoFy08Xn-ew,2836
|
122
126
|
flwr/server/strategy/aggregate.py,sha256=QyRIJtI5gnuY1NbgrcrOvkHxGIxBvApq7d9Y4xl-6W4,13468
|
123
127
|
flwr/server/strategy/bulyan.py,sha256=8GsSVJzRSoSWE2zQUKqC3Z795grdN9xpmc3MSGGXnzM,6532
|
124
|
-
flwr/server/strategy/dp_adaptive_clipping.py,sha256=
|
128
|
+
flwr/server/strategy/dp_adaptive_clipping.py,sha256=BVvX1LivyukvEtOZVZnpgVpzH8BBjvA3OmdGwFxgRuQ,16679
|
125
129
|
flwr/server/strategy/dp_fixed_clipping.py,sha256=v9YyX53jt2RatGnFxTK4ZMO_3SN7EdL9YCaaJtn9Fcc,12125
|
126
130
|
flwr/server/strategy/dpfedavg_adaptive.py,sha256=hLJkPQJl1bHjwrBNg3PSRFKf3no0hg5EHiFaWhHlWqw,4877
|
127
131
|
flwr/server/strategy/dpfedavg_fixed.py,sha256=G0yYxrPoM-MHQ889DYN3OeNiEeU0yQrjgAzcq0G653w,7219
|
@@ -145,7 +149,7 @@ flwr/server/strategy/strategy.py,sha256=g6VoIFogEviRub6G4QsKdIp6M_Ek6GhBhqcdNx5u
|
|
145
149
|
flwr/server/superlink/__init__.py,sha256=8tHYCfodUlRD8PCP9fHgvu8cz5N31A2QoRVL0jDJ15E,707
|
146
150
|
flwr/server/superlink/driver/__init__.py,sha256=STB1_DASVEg7Cu6L7VYxTzV7UMkgtBkFim09Z82Dh8I,712
|
147
151
|
flwr/server/superlink/driver/driver_grpc.py,sha256=1qSGDs1k_OVPWxp2ofxvQgtYXExrMeC3N_rNPVWH65M,1932
|
148
|
-
flwr/server/superlink/driver/driver_servicer.py,sha256=
|
152
|
+
flwr/server/superlink/driver/driver_servicer.py,sha256=p1rlocgqU2I4s5IwdU8rTZBkQ73yPmuWtK_aUlB7V84,4573
|
149
153
|
flwr/server/superlink/fleet/__init__.py,sha256=C6GCSD5eP5Of6_dIeSe1jx9HnV0icsvWyQ5EKAUHJRU,711
|
150
154
|
flwr/server/superlink/fleet/grpc_bidi/__init__.py,sha256=mgGJGjwT6VU7ovC1gdnnqttjyBPlNIcZnYRqx4K3IBQ,735
|
151
155
|
flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py,sha256=57b3UL5-baGdLwgCtB0dCUTTSbmmfMAXcXV5bjPZNWQ,5993
|
@@ -174,14 +178,15 @@ flwr/server/utils/tensorboard.py,sha256=k0G6bqsLx7wfYbH2KtXsDYcOCfyIeE12-hefXA7l
|
|
174
178
|
flwr/server/utils/validator.py,sha256=IJN2475yyD_i_9kg_SJ_JodIuZh58ufpWGUDQRAqu2s,4740
|
175
179
|
flwr/server/workflow/__init__.py,sha256=2YrKq5wUwge8Tm1xaAdf2P3l4LbM4olka6tO_0_Mu9A,787
|
176
180
|
flwr/server/workflow/default_workflows.py,sha256=DKSt14WY5m19ujwh6UDP4a31kRs-j6V_NZm5cXn01ZY,12705
|
177
|
-
flwr/simulation/__init__.py,sha256=
|
181
|
+
flwr/simulation/__init__.py,sha256=jdrJeTnLLj9Eyl8BRPXMewqkhTnxD7fvXDgyjfspy0Q,1359
|
178
182
|
flwr/simulation/app.py,sha256=WqJxdXTEuehwMW605p5NMmvBbKYx5tuqnV3Mp7jSWXM,13904
|
179
183
|
flwr/simulation/ray_transport/__init__.py,sha256=FsaAnzC4cw4DqoouBCix6496k29jACkfeIam55BvW9g,734
|
180
184
|
flwr/simulation/ray_transport/ray_actor.py,sha256=zRETW_xuCAOLRFaYnQ-q3IBSz0LIv_0RifGuhgWaYOg,19872
|
181
185
|
flwr/simulation/ray_transport/ray_client_proxy.py,sha256=DpmrBC87_sX3J4WrrwzyEDIjONUeliBZx9T-gZGuPmQ,6799
|
182
186
|
flwr/simulation/ray_transport/utils.py,sha256=TYdtfg1P9VfTdLMOJlifInGpxWHYs9UfUqIv2wfkRLA,2392
|
183
|
-
|
184
|
-
flwr_nightly-1.8.0.
|
185
|
-
flwr_nightly-1.8.0.
|
186
|
-
flwr_nightly-1.8.0.
|
187
|
-
flwr_nightly-1.8.0.
|
187
|
+
flwr/simulation/run_simulation.py,sha256=NYUFJ6cG5QtuwEl6f2IWJNMRo_jDmmA41DABqISpq-Q,5950
|
188
|
+
flwr_nightly-1.8.0.dev20240229.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
189
|
+
flwr_nightly-1.8.0.dev20240229.dist-info/METADATA,sha256=bnkE4rD-5EY00NUOdT-ZzoeuextoRXX3zZL0MLR6PbU,15184
|
190
|
+
flwr_nightly-1.8.0.dev20240229.dist-info/WHEEL,sha256=FMvqSimYX_P7y0a7UY-_Mc83r5zkBZsCYPm7Lr0Bsq4,88
|
191
|
+
flwr_nightly-1.8.0.dev20240229.dist-info/entry_points.txt,sha256=qz6t0YdMrV_PLbEarJ6ITIWSIRrTbG_jossZAYfXBZQ,311
|
192
|
+
flwr_nightly-1.8.0.dev20240229.dist-info/RECORD,,
|
@@ -3,6 +3,7 @@ flower-client-app=flwr.client:run_client_app
|
|
3
3
|
flower-driver-api=flwr.server:run_driver_api
|
4
4
|
flower-fleet-api=flwr.server:run_fleet_api
|
5
5
|
flower-server-app=flwr.server:run_server_app
|
6
|
+
flower-simulation=flwr.simulation:run_simulation
|
6
7
|
flower-superlink=flwr.server:run_superlink
|
7
8
|
flwr=flwr.cli.app:app
|
8
9
|
|
{flwr_nightly-1.8.0.dev20240228.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/LICENSE
RENAMED
File without changes
|
File without changes
|