flwr-nightly 1.8.0.dev20240228__py3-none-any.whl → 1.8.0.dev20240229__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +55 -75
- flwr/common/secure_aggregation/secaggplus_constants.py +49 -27
- flwr/proto/error_pb2.py +26 -0
- flwr/proto/error_pb2.pyi +25 -0
- flwr/proto/error_pb2_grpc.py +4 -0
- flwr/proto/error_pb2_grpc.pyi +4 -0
- flwr/proto/task_pb2.py +8 -7
- flwr/proto/task_pb2.pyi +7 -2
- flwr/server/strategy/__init__.py +11 -3
- flwr/server/strategy/dp_adaptive_clipping.py +205 -1
- flwr/server/superlink/driver/driver_servicer.py +4 -4
- flwr/simulation/__init__.py +3 -0
- flwr/simulation/run_simulation.py +177 -0
- {flwr_nightly-1.8.0.dev20240228.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/METADATA +1 -1
- {flwr_nightly-1.8.0.dev20240228.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/RECORD +18 -13
- {flwr_nightly-1.8.0.dev20240228.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/entry_points.txt +1 -0
- {flwr_nightly-1.8.0.dev20240228.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240228.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/WHEEL +0 -0
@@ -52,31 +52,10 @@ from flwr.common.secure_aggregation.ndarrays_arithmetic import (
|
|
52
52
|
)
|
53
53
|
from flwr.common.secure_aggregation.quantization import quantize
|
54
54
|
from flwr.common.secure_aggregation.secaggplus_constants import (
|
55
|
-
KEY_ACTIVE_SECURE_ID_LIST,
|
56
|
-
KEY_CIPHERTEXT_LIST,
|
57
|
-
KEY_CLIPPING_RANGE,
|
58
|
-
KEY_DEAD_SECURE_ID_LIST,
|
59
|
-
KEY_DESTINATION_LIST,
|
60
|
-
KEY_MASKED_PARAMETERS,
|
61
|
-
KEY_MOD_RANGE,
|
62
|
-
KEY_PUBLIC_KEY_1,
|
63
|
-
KEY_PUBLIC_KEY_2,
|
64
|
-
KEY_SAMPLE_NUMBER,
|
65
|
-
KEY_SECURE_ID,
|
66
|
-
KEY_SECURE_ID_LIST,
|
67
|
-
KEY_SHARE_LIST,
|
68
|
-
KEY_SHARE_NUMBER,
|
69
|
-
KEY_SOURCE_LIST,
|
70
|
-
KEY_STAGE,
|
71
|
-
KEY_TARGET_RANGE,
|
72
|
-
KEY_THRESHOLD,
|
73
55
|
RECORD_KEY_CONFIGS,
|
74
56
|
RECORD_KEY_STATE,
|
75
|
-
|
76
|
-
|
77
|
-
STAGE_SHARE_KEYS,
|
78
|
-
STAGE_UNMASK,
|
79
|
-
STAGES,
|
57
|
+
Key,
|
58
|
+
Stage,
|
80
59
|
)
|
81
60
|
from flwr.common.secure_aggregation.secaggplus_utils import (
|
82
61
|
pseudo_rand_gen,
|
@@ -91,7 +70,7 @@ from flwr.common.typing import ConfigsRecordValues, FitRes
|
|
91
70
|
class SecAggPlusState:
|
92
71
|
"""State of the SecAgg+ protocol."""
|
93
72
|
|
94
|
-
current_stage: str =
|
73
|
+
current_stage: str = Stage.UNMASK
|
95
74
|
|
96
75
|
sid: int = 0
|
97
76
|
sample_num: int = 0
|
@@ -187,20 +166,20 @@ def secaggplus_mod(
|
|
187
166
|
check_stage(state.current_stage, configs)
|
188
167
|
|
189
168
|
# Update the current stage
|
190
|
-
state.current_stage = cast(str, configs.pop(
|
169
|
+
state.current_stage = cast(str, configs.pop(Key.STAGE))
|
191
170
|
|
192
171
|
# Check the validity of the configs based on the current stage
|
193
172
|
check_configs(state.current_stage, configs)
|
194
173
|
|
195
174
|
# Execute
|
196
|
-
if state.current_stage ==
|
175
|
+
if state.current_stage == Stage.SETUP:
|
197
176
|
res = _setup(state, configs)
|
198
|
-
elif state.current_stage ==
|
177
|
+
elif state.current_stage == Stage.SHARE_KEYS:
|
199
178
|
res = _share_keys(state, configs)
|
200
|
-
elif state.current_stage ==
|
179
|
+
elif state.current_stage == Stage.COLLECT_MASKED_INPUT:
|
201
180
|
fit = _get_fit_fn(msg, ctxt, call_next)
|
202
181
|
res = _collect_masked_input(state, configs, fit)
|
203
|
-
elif state.current_stage ==
|
182
|
+
elif state.current_stage == Stage.UNMASK:
|
204
183
|
res = _unmask(state, configs)
|
205
184
|
else:
|
206
185
|
raise ValueError(f"Unknown secagg stage: {state.current_stage}")
|
@@ -215,28 +194,29 @@ def secaggplus_mod(
|
|
215
194
|
|
216
195
|
def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
|
217
196
|
"""Check the validity of the next stage."""
|
218
|
-
# Check the existence of
|
219
|
-
if
|
197
|
+
# Check the existence of Config.STAGE
|
198
|
+
if Key.STAGE not in configs:
|
220
199
|
raise KeyError(
|
221
|
-
f"The required key '{
|
200
|
+
f"The required key '{Key.STAGE}' is missing from the input `named_values`."
|
222
201
|
)
|
223
202
|
|
224
|
-
# Check the value type of the
|
225
|
-
next_stage = configs[
|
203
|
+
# Check the value type of the Config.STAGE
|
204
|
+
next_stage = configs[Key.STAGE]
|
226
205
|
if not isinstance(next_stage, str):
|
227
206
|
raise TypeError(
|
228
|
-
f"The value for the key '{
|
207
|
+
f"The value for the key '{Key.STAGE}' must be of type {str}, "
|
229
208
|
f"but got {type(next_stage)} instead."
|
230
209
|
)
|
231
210
|
|
232
211
|
# Check the validity of the next stage
|
233
|
-
if next_stage ==
|
234
|
-
if current_stage !=
|
212
|
+
if next_stage == Stage.SETUP:
|
213
|
+
if current_stage != Stage.UNMASK:
|
235
214
|
log(WARNING, "Restart from the setup stage")
|
236
215
|
# If stage is not "setup",
|
237
216
|
# the stage from `named_values` should be the expected next stage
|
238
217
|
else:
|
239
|
-
|
218
|
+
stages = Stage.all()
|
219
|
+
expected_next_stage = stages[(stages.index(current_stage) + 1) % len(stages)]
|
240
220
|
if next_stage != expected_next_stage:
|
241
221
|
raise ValueError(
|
242
222
|
"Abort secure aggregation: "
|
@@ -248,20 +228,20 @@ def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
|
|
248
228
|
def check_configs(stage: str, configs: ConfigsRecord) -> None:
|
249
229
|
"""Check the validity of the configs."""
|
250
230
|
# Check `named_values` for the setup stage
|
251
|
-
if stage ==
|
231
|
+
if stage == Stage.SETUP:
|
252
232
|
key_type_pairs = [
|
253
|
-
(
|
254
|
-
(
|
255
|
-
(
|
256
|
-
(
|
257
|
-
(
|
258
|
-
(
|
259
|
-
(
|
233
|
+
(Key.SAMPLE_NUMBER, int),
|
234
|
+
(Key.SECURE_ID, int),
|
235
|
+
(Key.SHARE_NUMBER, int),
|
236
|
+
(Key.THRESHOLD, int),
|
237
|
+
(Key.CLIPPING_RANGE, float),
|
238
|
+
(Key.TARGET_RANGE, int),
|
239
|
+
(Key.MOD_RANGE, int),
|
260
240
|
]
|
261
241
|
for key, expected_type in key_type_pairs:
|
262
242
|
if key not in configs:
|
263
243
|
raise KeyError(
|
264
|
-
f"Stage {
|
244
|
+
f"Stage {Stage.SETUP}: the required key '{key}' is "
|
265
245
|
"missing from the input `named_values`."
|
266
246
|
)
|
267
247
|
# Bool is a subclass of int in Python,
|
@@ -269,11 +249,11 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
|
|
269
249
|
# pylint: disable-next=unidiomatic-typecheck
|
270
250
|
if type(configs[key]) is not expected_type:
|
271
251
|
raise TypeError(
|
272
|
-
f"Stage {
|
252
|
+
f"Stage {Stage.SETUP}: The value for the key '{key}' "
|
273
253
|
f"must be of type {expected_type}, "
|
274
254
|
f"but got {type(configs[key])} instead."
|
275
255
|
)
|
276
|
-
elif stage ==
|
256
|
+
elif stage == Stage.SHARE_KEYS:
|
277
257
|
for key, value in configs.items():
|
278
258
|
if (
|
279
259
|
not isinstance(value, list)
|
@@ -282,18 +262,18 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
|
|
282
262
|
or not isinstance(value[1], bytes)
|
283
263
|
):
|
284
264
|
raise TypeError(
|
285
|
-
f"Stage {
|
265
|
+
f"Stage {Stage.SHARE_KEYS}: "
|
286
266
|
f"the value for the key '{key}' must be a list of two bytes."
|
287
267
|
)
|
288
|
-
elif stage ==
|
268
|
+
elif stage == Stage.COLLECT_MASKED_INPUT:
|
289
269
|
key_type_pairs = [
|
290
|
-
(
|
291
|
-
(
|
270
|
+
(Key.CIPHERTEXT_LIST, bytes),
|
271
|
+
(Key.SOURCE_LIST, int),
|
292
272
|
]
|
293
273
|
for key, expected_type in key_type_pairs:
|
294
274
|
if key not in configs:
|
295
275
|
raise KeyError(
|
296
|
-
f"Stage {
|
276
|
+
f"Stage {Stage.COLLECT_MASKED_INPUT}: "
|
297
277
|
f"the required key '{key}' is "
|
298
278
|
"missing from the input `named_values`."
|
299
279
|
)
|
@@ -304,19 +284,19 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
|
|
304
284
|
if type(elm) is not expected_type
|
305
285
|
):
|
306
286
|
raise TypeError(
|
307
|
-
f"Stage {
|
287
|
+
f"Stage {Stage.COLLECT_MASKED_INPUT}: "
|
308
288
|
f"the value for the key '{key}' "
|
309
289
|
f"must be of type List[{expected_type.__name__}]"
|
310
290
|
)
|
311
|
-
elif stage ==
|
291
|
+
elif stage == Stage.UNMASK:
|
312
292
|
key_type_pairs = [
|
313
|
-
(
|
314
|
-
(
|
293
|
+
(Key.ACTIVE_SECURE_ID_LIST, int),
|
294
|
+
(Key.DEAD_SECURE_ID_LIST, int),
|
315
295
|
]
|
316
296
|
for key, expected_type in key_type_pairs:
|
317
297
|
if key not in configs:
|
318
298
|
raise KeyError(
|
319
|
-
f"Stage {
|
299
|
+
f"Stage {Stage.UNMASK}: "
|
320
300
|
f"the required key '{key}' is "
|
321
301
|
"missing from the input `named_values`."
|
322
302
|
)
|
@@ -327,7 +307,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
|
|
327
307
|
if type(elm) is not expected_type
|
328
308
|
):
|
329
309
|
raise TypeError(
|
330
|
-
f"Stage {
|
310
|
+
f"Stage {Stage.UNMASK}: "
|
331
311
|
f"the value for the key '{key}' "
|
332
312
|
f"must be of type List[{expected_type.__name__}]"
|
333
313
|
)
|
@@ -340,15 +320,15 @@ def _setup(
|
|
340
320
|
) -> Dict[str, ConfigsRecordValues]:
|
341
321
|
# Assigning parameter values to object fields
|
342
322
|
sec_agg_param_dict = configs
|
343
|
-
state.sample_num = cast(int, sec_agg_param_dict[
|
344
|
-
state.sid = cast(int, sec_agg_param_dict[
|
323
|
+
state.sample_num = cast(int, sec_agg_param_dict[Key.SAMPLE_NUMBER])
|
324
|
+
state.sid = cast(int, sec_agg_param_dict[Key.SECURE_ID])
|
345
325
|
log(INFO, "Client %d: starting stage 0...", state.sid)
|
346
326
|
|
347
|
-
state.share_num = cast(int, sec_agg_param_dict[
|
348
|
-
state.threshold = cast(int, sec_agg_param_dict[
|
349
|
-
state.clipping_range = cast(float, sec_agg_param_dict[
|
350
|
-
state.target_range = cast(int, sec_agg_param_dict[
|
351
|
-
state.mod_range = cast(int, sec_agg_param_dict[
|
327
|
+
state.share_num = cast(int, sec_agg_param_dict[Key.SHARE_NUMBER])
|
328
|
+
state.threshold = cast(int, sec_agg_param_dict[Key.THRESHOLD])
|
329
|
+
state.clipping_range = cast(float, sec_agg_param_dict[Key.CLIPPING_RANGE])
|
330
|
+
state.target_range = cast(int, sec_agg_param_dict[Key.TARGET_RANGE])
|
331
|
+
state.mod_range = cast(int, sec_agg_param_dict[Key.MOD_RANGE])
|
352
332
|
|
353
333
|
# Dictionaries containing client secure IDs as keys
|
354
334
|
# and their respective secret shares as values.
|
@@ -367,7 +347,7 @@ def _setup(
|
|
367
347
|
state.sk1, state.pk1 = private_key_to_bytes(sk1), public_key_to_bytes(pk1)
|
368
348
|
state.sk2, state.pk2 = private_key_to_bytes(sk2), public_key_to_bytes(pk2)
|
369
349
|
log(INFO, "Client %d: stage 0 completes. uploading public keys...", state.sid)
|
370
|
-
return {
|
350
|
+
return {Key.PUBLIC_KEY_1: state.pk1, Key.PUBLIC_KEY_2: state.pk2}
|
371
351
|
|
372
352
|
|
373
353
|
# pylint: disable-next=too-many-locals
|
@@ -429,7 +409,7 @@ def _share_keys(
|
|
429
409
|
ciphertexts.append(ciphertext)
|
430
410
|
|
431
411
|
log(INFO, "Client %d: stage 1 completes. uploading key shares...", state.sid)
|
432
|
-
return {
|
412
|
+
return {Key.DESTINATION_LIST: dsts, Key.CIPHERTEXT_LIST: ciphertexts}
|
433
413
|
|
434
414
|
|
435
415
|
# pylint: disable-next=too-many-locals
|
@@ -440,8 +420,8 @@ def _collect_masked_input(
|
|
440
420
|
) -> Dict[str, ConfigsRecordValues]:
|
441
421
|
log(INFO, "Client %d: starting stage 2...", state.sid)
|
442
422
|
available_clients: List[int] = []
|
443
|
-
ciphertexts = cast(List[bytes], configs[
|
444
|
-
srcs = cast(List[int], configs[
|
423
|
+
ciphertexts = cast(List[bytes], configs[Key.CIPHERTEXT_LIST])
|
424
|
+
srcs = cast(List[int], configs[Key.SOURCE_LIST])
|
445
425
|
if len(ciphertexts) + 1 < state.threshold:
|
446
426
|
raise ValueError("Not enough available neighbour clients.")
|
447
427
|
|
@@ -505,7 +485,7 @@ def _collect_masked_input(
|
|
505
485
|
quantized_parameters = parameters_mod(quantized_parameters, state.mod_range)
|
506
486
|
log(INFO, "Client %d: stage 2 completes. uploading masked parameters...", state.sid)
|
507
487
|
return {
|
508
|
-
|
488
|
+
Key.MASKED_PARAMETERS: [ndarray_to_bytes(arr) for arr in quantized_parameters]
|
509
489
|
}
|
510
490
|
|
511
491
|
|
@@ -514,8 +494,8 @@ def _unmask(
|
|
514
494
|
) -> Dict[str, ConfigsRecordValues]:
|
515
495
|
log(INFO, "Client %d: starting stage 3...", state.sid)
|
516
496
|
|
517
|
-
active_sids = cast(List[int], configs[
|
518
|
-
dead_sids = cast(List[int], configs[
|
497
|
+
active_sids = cast(List[int], configs[Key.ACTIVE_SECURE_ID_LIST])
|
498
|
+
dead_sids = cast(List[int], configs[Key.DEAD_SECURE_ID_LIST])
|
519
499
|
# Send private mask seed share for every avaliable client (including itclient)
|
520
500
|
# Send first private key share for building pairwise mask for every dropped client
|
521
501
|
if len(active_sids) < state.threshold:
|
@@ -528,4 +508,4 @@ def _unmask(
|
|
528
508
|
shares += [state.sk1_share_dict[sid] for sid in dead_sids]
|
529
509
|
|
530
510
|
log(INFO, "Client %d: stage 3 completes. uploading key shares...", state.sid)
|
531
|
-
return {
|
511
|
+
return {Key.SECURE_ID_LIST: sids, Key.SHARE_LIST: shares}
|
@@ -14,33 +14,55 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
"""Constants for the SecAgg/SecAgg+ protocol."""
|
16
16
|
|
17
|
+
|
18
|
+
from __future__ import annotations
|
19
|
+
|
17
20
|
RECORD_KEY_STATE = "secaggplus_state"
|
18
21
|
RECORD_KEY_CONFIGS = "secaggplus_configs"
|
19
22
|
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
23
|
+
|
24
|
+
class Stage:
|
25
|
+
"""Stages for the SecAgg+ protocol."""
|
26
|
+
|
27
|
+
SETUP = "setup"
|
28
|
+
SHARE_KEYS = "share_keys"
|
29
|
+
COLLECT_MASKED_INPUT = "collect_masked_input"
|
30
|
+
UNMASK = "unmask"
|
31
|
+
_stages = (SETUP, SHARE_KEYS, COLLECT_MASKED_INPUT, UNMASK)
|
32
|
+
|
33
|
+
@classmethod
|
34
|
+
def all(cls) -> tuple[str, str, str, str]:
|
35
|
+
"""Return all stages."""
|
36
|
+
return cls._stages
|
37
|
+
|
38
|
+
def __new__(cls) -> Stage:
|
39
|
+
"""Prevent instantiation."""
|
40
|
+
raise TypeError(f"{cls.__name__} cannot be instantiated.")
|
41
|
+
|
42
|
+
|
43
|
+
class Key:
|
44
|
+
"""Keys for the configs in the ConfigsRecord."""
|
45
|
+
|
46
|
+
STAGE = "stage"
|
47
|
+
SAMPLE_NUMBER = "sample_num"
|
48
|
+
SECURE_ID = "secure_id"
|
49
|
+
SHARE_NUMBER = "share_num"
|
50
|
+
THRESHOLD = "threshold"
|
51
|
+
CLIPPING_RANGE = "clipping_range"
|
52
|
+
TARGET_RANGE = "target_range"
|
53
|
+
MOD_RANGE = "mod_range"
|
54
|
+
PUBLIC_KEY_1 = "pk1"
|
55
|
+
PUBLIC_KEY_2 = "pk2"
|
56
|
+
DESTINATION_LIST = "dsts"
|
57
|
+
CIPHERTEXT_LIST = "ctxts"
|
58
|
+
SOURCE_LIST = "srcs"
|
59
|
+
PARAMETERS = "params"
|
60
|
+
MASKED_PARAMETERS = "masked_params"
|
61
|
+
ACTIVE_SECURE_ID_LIST = "active_sids"
|
62
|
+
DEAD_SECURE_ID_LIST = "dead_sids"
|
63
|
+
SECURE_ID_LIST = "sids"
|
64
|
+
SHARE_LIST = "shares"
|
65
|
+
|
66
|
+
def __new__(cls) -> Key:
|
67
|
+
"""Prevent instantiation."""
|
68
|
+
raise TypeError(f"{cls.__name__} cannot be instantiated.")
|
flwr/proto/error_pb2.py
ADDED
@@ -0,0 +1,26 @@
|
|
1
|
+
# -*- coding: utf-8 -*-
|
2
|
+
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
3
|
+
# source: flwr/proto/error.proto
|
4
|
+
# Protobuf Python Version: 4.25.0
|
5
|
+
"""Generated protocol buffer code."""
|
6
|
+
from google.protobuf import descriptor as _descriptor
|
7
|
+
from google.protobuf import descriptor_pool as _descriptor_pool
|
8
|
+
from google.protobuf import symbol_database as _symbol_database
|
9
|
+
from google.protobuf.internal import builder as _builder
|
10
|
+
# @@protoc_insertion_point(imports)
|
11
|
+
|
12
|
+
_sym_db = _symbol_database.Default()
|
13
|
+
|
14
|
+
|
15
|
+
|
16
|
+
|
17
|
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/error.proto\x12\nflwr.proto\"%\n\x05\x45rror\x12\x0c\n\x04\x63ode\x18\x01 \x01(\x12\x12\x0e\n\x06reason\x18\x02 \x01(\tb\x06proto3')
|
18
|
+
|
19
|
+
_globals = globals()
|
20
|
+
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
21
|
+
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.error_pb2', _globals)
|
22
|
+
if _descriptor._USE_C_DESCRIPTORS == False:
|
23
|
+
DESCRIPTOR._options = None
|
24
|
+
_globals['_ERROR']._serialized_start=38
|
25
|
+
_globals['_ERROR']._serialized_end=75
|
26
|
+
# @@protoc_insertion_point(module_scope)
|
flwr/proto/error_pb2.pyi
ADDED
@@ -0,0 +1,25 @@
|
|
1
|
+
"""
|
2
|
+
@generated by mypy-protobuf. Do not edit manually!
|
3
|
+
isort:skip_file
|
4
|
+
"""
|
5
|
+
import builtins
|
6
|
+
import google.protobuf.descriptor
|
7
|
+
import google.protobuf.message
|
8
|
+
import typing
|
9
|
+
import typing_extensions
|
10
|
+
|
11
|
+
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
|
12
|
+
|
13
|
+
class Error(google.protobuf.message.Message):
|
14
|
+
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
15
|
+
CODE_FIELD_NUMBER: builtins.int
|
16
|
+
REASON_FIELD_NUMBER: builtins.int
|
17
|
+
code: builtins.int
|
18
|
+
reason: typing.Text
|
19
|
+
def __init__(self,
|
20
|
+
*,
|
21
|
+
code: builtins.int = ...,
|
22
|
+
reason: typing.Text = ...,
|
23
|
+
) -> None: ...
|
24
|
+
def ClearField(self, field_name: typing_extensions.Literal["code",b"code","reason",b"reason"]) -> None: ...
|
25
|
+
global___Error = Error
|
flwr/proto/task_pb2.py
CHANGED
@@ -15,19 +15,20 @@ _sym_db = _symbol_database.Default()
|
|
15
15
|
from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2
|
16
16
|
from flwr.proto import recordset_pb2 as flwr_dot_proto_dot_recordset__pb2
|
17
17
|
from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2
|
18
|
+
from flwr.proto import error_pb2 as flwr_dot_proto_dot_error__pb2
|
18
19
|
|
19
20
|
|
20
|
-
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\"\
|
21
|
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x16\x66lwr/proto/error.proto\"\xf6\x01\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\t\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x0b\n\x03ttl\x18\x05 \x01(\t\x12\x10\n\x08\x61ncestry\x18\x06 \x03(\t\x12\x11\n\ttask_type\x18\x07 \x01(\t\x12(\n\trecordset\x18\x08 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12 \n\x05\x65rror\x18\t \x01(\x0b\x32\x11.flwr.proto.Error\"\\\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\\\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Taskb\x06proto3')
|
21
22
|
|
22
23
|
_globals = globals()
|
23
24
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
24
25
|
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.task_pb2', _globals)
|
25
26
|
if _descriptor._USE_C_DESCRIPTORS == False:
|
26
27
|
DESCRIPTOR._options = None
|
27
|
-
_globals['_TASK']._serialized_start=
|
28
|
-
_globals['_TASK']._serialized_end=
|
29
|
-
_globals['_TASKINS']._serialized_start=
|
30
|
-
_globals['_TASKINS']._serialized_end=
|
31
|
-
_globals['_TASKRES']._serialized_start=
|
32
|
-
_globals['_TASKRES']._serialized_end=
|
28
|
+
_globals['_TASK']._serialized_start=141
|
29
|
+
_globals['_TASK']._serialized_end=387
|
30
|
+
_globals['_TASKINS']._serialized_start=389
|
31
|
+
_globals['_TASKINS']._serialized_end=481
|
32
|
+
_globals['_TASKRES']._serialized_start=483
|
33
|
+
_globals['_TASKRES']._serialized_end=575
|
33
34
|
# @@protoc_insertion_point(module_scope)
|
flwr/proto/task_pb2.pyi
CHANGED
@@ -3,6 +3,7 @@
|
|
3
3
|
isort:skip_file
|
4
4
|
"""
|
5
5
|
import builtins
|
6
|
+
import flwr.proto.error_pb2
|
6
7
|
import flwr.proto.node_pb2
|
7
8
|
import flwr.proto.recordset_pb2
|
8
9
|
import google.protobuf.descriptor
|
@@ -23,6 +24,7 @@ class Task(google.protobuf.message.Message):
|
|
23
24
|
ANCESTRY_FIELD_NUMBER: builtins.int
|
24
25
|
TASK_TYPE_FIELD_NUMBER: builtins.int
|
25
26
|
RECORDSET_FIELD_NUMBER: builtins.int
|
27
|
+
ERROR_FIELD_NUMBER: builtins.int
|
26
28
|
@property
|
27
29
|
def producer(self) -> flwr.proto.node_pb2.Node: ...
|
28
30
|
@property
|
@@ -35,6 +37,8 @@ class Task(google.protobuf.message.Message):
|
|
35
37
|
task_type: typing.Text
|
36
38
|
@property
|
37
39
|
def recordset(self) -> flwr.proto.recordset_pb2.RecordSet: ...
|
40
|
+
@property
|
41
|
+
def error(self) -> flwr.proto.error_pb2.Error: ...
|
38
42
|
def __init__(self,
|
39
43
|
*,
|
40
44
|
producer: typing.Optional[flwr.proto.node_pb2.Node] = ...,
|
@@ -45,9 +49,10 @@ class Task(google.protobuf.message.Message):
|
|
45
49
|
ancestry: typing.Optional[typing.Iterable[typing.Text]] = ...,
|
46
50
|
task_type: typing.Text = ...,
|
47
51
|
recordset: typing.Optional[flwr.proto.recordset_pb2.RecordSet] = ...,
|
52
|
+
error: typing.Optional[flwr.proto.error_pb2.Error] = ...,
|
48
53
|
) -> None: ...
|
49
|
-
def HasField(self, field_name: typing_extensions.Literal["consumer",b"consumer","producer",b"producer","recordset",b"recordset"]) -> builtins.bool: ...
|
50
|
-
def ClearField(self, field_name: typing_extensions.Literal["ancestry",b"ancestry","consumer",b"consumer","created_at",b"created_at","delivered_at",b"delivered_at","producer",b"producer","recordset",b"recordset","task_type",b"task_type","ttl",b"ttl"]) -> None: ...
|
54
|
+
def HasField(self, field_name: typing_extensions.Literal["consumer",b"consumer","error",b"error","producer",b"producer","recordset",b"recordset"]) -> builtins.bool: ...
|
55
|
+
def ClearField(self, field_name: typing_extensions.Literal["ancestry",b"ancestry","consumer",b"consumer","created_at",b"created_at","delivered_at",b"delivered_at","error",b"error","producer",b"producer","recordset",b"recordset","task_type",b"task_type","ttl",b"ttl"]) -> None: ...
|
51
56
|
global___Task = Task
|
52
57
|
|
53
58
|
class TaskIns(google.protobuf.message.Message):
|
flwr/server/strategy/__init__.py
CHANGED
@@ -16,10 +16,17 @@
|
|
16
16
|
|
17
17
|
|
18
18
|
from .bulyan import Bulyan as Bulyan
|
19
|
-
from .dp_adaptive_clipping import
|
19
|
+
from .dp_adaptive_clipping import (
|
20
|
+
DifferentialPrivacyClientSideAdaptiveClipping as DifferentialPrivacyClientSideAdaptiveClipping,
|
21
|
+
)
|
22
|
+
from .dp_adaptive_clipping import (
|
23
|
+
DifferentialPrivacyServerSideAdaptiveClipping as DifferentialPrivacyServerSideAdaptiveClipping,
|
24
|
+
)
|
25
|
+
from .dp_fixed_clipping import (
|
26
|
+
DifferentialPrivacyClientSideFixedClipping as DifferentialPrivacyClientSideFixedClipping,
|
27
|
+
)
|
20
28
|
from .dp_fixed_clipping import (
|
21
|
-
|
22
|
-
DifferentialPrivacyServerSideFixedClipping,
|
29
|
+
DifferentialPrivacyServerSideFixedClipping as DifferentialPrivacyServerSideFixedClipping,
|
23
30
|
)
|
24
31
|
from .dpfedavg_adaptive import DPFedAvgAdaptive as DPFedAvgAdaptive
|
25
32
|
from .dpfedavg_fixed import DPFedAvgFixed as DPFedAvgFixed
|
@@ -46,6 +53,7 @@ __all__ = [
|
|
46
53
|
"DPFedAvgAdaptive",
|
47
54
|
"DPFedAvgFixed",
|
48
55
|
"DifferentialPrivacyClientSideAdaptiveClipping",
|
56
|
+
"DifferentialPrivacyServerSideAdaptiveClipping",
|
49
57
|
"DifferentialPrivacyClientSideFixedClipping",
|
50
58
|
"DifferentialPrivacyServerSideFixedClipping",
|
51
59
|
"FedAdagrad",
|
@@ -24,8 +24,19 @@ from typing import Dict, List, Optional, Tuple, Union
|
|
24
24
|
|
25
25
|
import numpy as np
|
26
26
|
|
27
|
-
from flwr.common import
|
27
|
+
from flwr.common import (
|
28
|
+
EvaluateIns,
|
29
|
+
EvaluateRes,
|
30
|
+
FitIns,
|
31
|
+
FitRes,
|
32
|
+
NDArrays,
|
33
|
+
Parameters,
|
34
|
+
Scalar,
|
35
|
+
ndarrays_to_parameters,
|
36
|
+
parameters_to_ndarrays,
|
37
|
+
)
|
28
38
|
from flwr.common.differential_privacy import (
|
39
|
+
adaptive_clip_inputs_inplace,
|
29
40
|
add_gaussian_noise_to_params,
|
30
41
|
compute_adaptive_noise_params,
|
31
42
|
)
|
@@ -40,6 +51,199 @@ from flwr.server.client_proxy import ClientProxy
|
|
40
51
|
from flwr.server.strategy.strategy import Strategy
|
41
52
|
|
42
53
|
|
54
|
+
class DifferentialPrivacyServerSideAdaptiveClipping(Strategy):
|
55
|
+
"""Strategy wrapper for central DP with server-side adaptive clipping.
|
56
|
+
|
57
|
+
Parameters
|
58
|
+
----------
|
59
|
+
strategy: Strategy
|
60
|
+
The strategy to which DP functionalities will be added by this wrapper.
|
61
|
+
noise_multiplier : float
|
62
|
+
The noise multiplier for the Gaussian mechanism for model updates.
|
63
|
+
num_sampled_clients : int
|
64
|
+
The number of clients that are sampled on each round.
|
65
|
+
initial_clipping_norm : float
|
66
|
+
The initial value of clipping norm. Deafults to 0.1.
|
67
|
+
Andrew et al. recommends to set to 0.1.
|
68
|
+
target_clipped_quantile : float
|
69
|
+
The desired quantile of updates which should be clipped. Defaults to 0.5.
|
70
|
+
clip_norm_lr : float
|
71
|
+
The learning rate for the clipping norm adaptation. Defaults to 0.2.
|
72
|
+
Andrew et al. recommends to set to 0.2.
|
73
|
+
clipped_count_stddev : float
|
74
|
+
The standard deviation of the noise added to the count of updates below the estimate.
|
75
|
+
Andrew et al. recommends to set to `expected_num_records/20`
|
76
|
+
|
77
|
+
Examples
|
78
|
+
--------
|
79
|
+
Create a strategy:
|
80
|
+
|
81
|
+
>>> strategy = fl.server.strategy.FedAvg( ... )
|
82
|
+
|
83
|
+
Wrap the strategy with the DifferentialPrivacyServerSideAdaptiveClipping wrapper
|
84
|
+
|
85
|
+
>>> dp_strategy = DifferentialPrivacyServerSideAdaptiveClipping(
|
86
|
+
>>> strategy, cfg.noise_multiplier, cfg.num_sampled_clients, ...
|
87
|
+
>>> )
|
88
|
+
"""
|
89
|
+
|
90
|
+
# pylint: disable=too-many-arguments,too-many-instance-attributes
|
91
|
+
def __init__(
|
92
|
+
self,
|
93
|
+
strategy: Strategy,
|
94
|
+
noise_multiplier: float,
|
95
|
+
num_sampled_clients: int,
|
96
|
+
initial_clipping_norm: float = 0.1,
|
97
|
+
target_clipped_quantile: float = 0.5,
|
98
|
+
clip_norm_lr: float = 0.2,
|
99
|
+
clipped_count_stddev: Optional[float] = None,
|
100
|
+
) -> None:
|
101
|
+
super().__init__()
|
102
|
+
|
103
|
+
if strategy is None:
|
104
|
+
raise ValueError("The passed strategy is None.")
|
105
|
+
|
106
|
+
if noise_multiplier < 0:
|
107
|
+
raise ValueError("The noise multiplier should be a non-negative value.")
|
108
|
+
|
109
|
+
if num_sampled_clients <= 0:
|
110
|
+
raise ValueError(
|
111
|
+
"The number of sampled clients should be a positive value."
|
112
|
+
)
|
113
|
+
|
114
|
+
if initial_clipping_norm <= 0:
|
115
|
+
raise ValueError("The initial clipping norm should be a positive value.")
|
116
|
+
|
117
|
+
if not 0 <= target_clipped_quantile <= 1:
|
118
|
+
raise ValueError(
|
119
|
+
"The target clipped quantile must be between 0 and 1 (inclusive)."
|
120
|
+
)
|
121
|
+
|
122
|
+
if clip_norm_lr <= 0:
|
123
|
+
raise ValueError("The learning rate must be positive.")
|
124
|
+
|
125
|
+
if clipped_count_stddev is not None:
|
126
|
+
if clipped_count_stddev < 0:
|
127
|
+
raise ValueError("The `clipped_count_stddev` must be non-negative.")
|
128
|
+
|
129
|
+
self.strategy = strategy
|
130
|
+
self.num_sampled_clients = num_sampled_clients
|
131
|
+
self.clipping_norm = initial_clipping_norm
|
132
|
+
self.target_clipped_quantile = target_clipped_quantile
|
133
|
+
self.clip_norm_lr = clip_norm_lr
|
134
|
+
(
|
135
|
+
self.clipped_count_stddev,
|
136
|
+
self.noise_multiplier,
|
137
|
+
) = compute_adaptive_noise_params(
|
138
|
+
noise_multiplier,
|
139
|
+
num_sampled_clients,
|
140
|
+
clipped_count_stddev,
|
141
|
+
)
|
142
|
+
|
143
|
+
self.current_round_params: NDArrays = []
|
144
|
+
|
145
|
+
def __repr__(self) -> str:
|
146
|
+
"""Compute a string representation of the strategy."""
|
147
|
+
rep = "Differential Privacy Strategy Wrapper (Server-Side Adaptive Clipping)"
|
148
|
+
return rep
|
149
|
+
|
150
|
+
def initialize_parameters(
|
151
|
+
self, client_manager: ClientManager
|
152
|
+
) -> Optional[Parameters]:
|
153
|
+
"""Initialize global model parameters using given strategy."""
|
154
|
+
return self.strategy.initialize_parameters(client_manager)
|
155
|
+
|
156
|
+
def configure_fit(
|
157
|
+
self, server_round: int, parameters: Parameters, client_manager: ClientManager
|
158
|
+
) -> List[Tuple[ClientProxy, FitIns]]:
|
159
|
+
"""Configure the next round of training."""
|
160
|
+
self.current_round_params = parameters_to_ndarrays(parameters)
|
161
|
+
return self.strategy.configure_fit(server_round, parameters, client_manager)
|
162
|
+
|
163
|
+
def configure_evaluate(
|
164
|
+
self, server_round: int, parameters: Parameters, client_manager: ClientManager
|
165
|
+
) -> List[Tuple[ClientProxy, EvaluateIns]]:
|
166
|
+
"""Configure the next round of evaluation."""
|
167
|
+
return self.strategy.configure_evaluate(
|
168
|
+
server_round, parameters, client_manager
|
169
|
+
)
|
170
|
+
|
171
|
+
def aggregate_fit(
|
172
|
+
self,
|
173
|
+
server_round: int,
|
174
|
+
results: List[Tuple[ClientProxy, FitRes]],
|
175
|
+
failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
|
176
|
+
) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
|
177
|
+
"""Aggregate training results and update clip norms."""
|
178
|
+
if failures:
|
179
|
+
return None, {}
|
180
|
+
|
181
|
+
if len(results) != self.num_sampled_clients:
|
182
|
+
log(
|
183
|
+
WARNING,
|
184
|
+
CLIENTS_DISCREPANCY_WARNING,
|
185
|
+
len(results),
|
186
|
+
self.num_sampled_clients,
|
187
|
+
)
|
188
|
+
|
189
|
+
norm_bit_set_count = 0
|
190
|
+
for _, res in results:
|
191
|
+
param = parameters_to_ndarrays(res.parameters)
|
192
|
+
# Compute and clip update
|
193
|
+
model_update = [
|
194
|
+
np.subtract(x, y) for (x, y) in zip(param, self.current_round_params)
|
195
|
+
]
|
196
|
+
|
197
|
+
norm_bit = adaptive_clip_inputs_inplace(model_update, self.clipping_norm)
|
198
|
+
norm_bit_set_count += norm_bit
|
199
|
+
|
200
|
+
for i, _ in enumerate(self.current_round_params):
|
201
|
+
param[i] = self.current_round_params[i] + model_update[i]
|
202
|
+
# Convert back to parameters
|
203
|
+
res.parameters = ndarrays_to_parameters(param)
|
204
|
+
|
205
|
+
# Noising the count
|
206
|
+
noised_norm_bit_set_count = float(
|
207
|
+
np.random.normal(norm_bit_set_count, self.clipped_count_stddev)
|
208
|
+
)
|
209
|
+
noised_norm_bit_set_fraction = noised_norm_bit_set_count / len(results)
|
210
|
+
# Geometric update
|
211
|
+
self.clipping_norm *= math.exp(
|
212
|
+
-self.clip_norm_lr
|
213
|
+
* (noised_norm_bit_set_fraction - self.target_clipped_quantile)
|
214
|
+
)
|
215
|
+
|
216
|
+
aggregated_params, metrics = self.strategy.aggregate_fit(
|
217
|
+
server_round, results, failures
|
218
|
+
)
|
219
|
+
|
220
|
+
# Add Gaussian noise to the aggregated parameters
|
221
|
+
if aggregated_params:
|
222
|
+
aggregated_params = add_gaussian_noise_to_params(
|
223
|
+
aggregated_params,
|
224
|
+
self.noise_multiplier,
|
225
|
+
self.clipping_norm,
|
226
|
+
self.num_sampled_clients,
|
227
|
+
)
|
228
|
+
|
229
|
+
return aggregated_params, metrics
|
230
|
+
|
231
|
+
def aggregate_evaluate(
    self,
    server_round: int,
    results: List[Tuple[ClientProxy, EvaluateRes]],
    failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
) -> Tuple[Optional[float], Dict[str, Scalar]]:
    """Aggregate evaluation losses using the given strategy.

    The call is forwarded unchanged to the wrapped ``self.strategy``; this
    wrapper adds no behaviour of its own to evaluation aggregation.

    Parameters
    ----------
    server_round : int
        The current server round, passed through as-is.
    results : List[Tuple[ClientProxy, EvaluateRes]]
        ``(client, EvaluateRes)`` pairs, passed through as-is.
    failures : List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]]
        Failed ``(client, EvaluateRes)`` pairs or exceptions, passed through
        as-is.

    Returns
    -------
    Tuple[Optional[float], Dict[str, Scalar]]
        Exactly what ``self.strategy.aggregate_evaluate`` returns.
    """
    inner_strategy = self.strategy
    aggregated = inner_strategy.aggregate_evaluate(server_round, results, failures)
    return aggregated
|
239
|
+
|
240
|
+
def evaluate(
    self, server_round: int, parameters: Parameters
) -> Optional[Tuple[float, Dict[str, Scalar]]]:
    """Evaluate model parameters using an evaluation function from the strategy.

    This is a pure delegation to ``self.strategy.evaluate``.

    Parameters
    ----------
    server_round : int
        The current server round, passed through as-is.
    parameters : Parameters
        The model parameters to evaluate, passed through as-is.

    Returns
    -------
    Optional[Tuple[float, Dict[str, Scalar]]]
        Exactly what ``self.strategy.evaluate`` returns.
    """
    inner_strategy = self.strategy
    outcome = inner_strategy.evaluate(server_round, parameters)
    return outcome
|
245
|
+
|
246
|
+
|
43
247
|
class DifferentialPrivacyClientSideAdaptiveClipping(Strategy):
|
44
248
|
"""Strategy wrapper for central DP with client-side adaptive clipping.
|
45
249
|
|
@@ -15,7 +15,7 @@
|
|
15
15
|
"""Driver API servicer."""
|
16
16
|
|
17
17
|
|
18
|
-
from logging import INFO
|
18
|
+
from logging import DEBUG, INFO
|
19
19
|
from typing import List, Optional, Set
|
20
20
|
from uuid import UUID
|
21
21
|
|
@@ -70,7 +70,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
70
70
|
self, request: PushTaskInsRequest, context: grpc.ServicerContext
|
71
71
|
) -> PushTaskInsResponse:
|
72
72
|
"""Push a set of TaskIns."""
|
73
|
-
log(
|
73
|
+
log(DEBUG, "DriverServicer.PushTaskIns")
|
74
74
|
|
75
75
|
# Validate request
|
76
76
|
_raise_if(len(request.task_ins_list) == 0, "`task_ins_list` must not be empty")
|
@@ -95,7 +95,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
95
95
|
self, request: PullTaskResRequest, context: grpc.ServicerContext
|
96
96
|
) -> PullTaskResResponse:
|
97
97
|
"""Pull a set of TaskRes."""
|
98
|
-
log(
|
98
|
+
log(DEBUG, "DriverServicer.PullTaskRes")
|
99
99
|
|
100
100
|
# Convert each task_id str to UUID
|
101
101
|
task_ids: Set[UUID] = {UUID(task_id) for task_id in request.task_ids}
|
@@ -105,7 +105,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
105
105
|
|
106
106
|
# Register callback
|
107
107
|
def on_rpc_done() -> None:
|
108
|
-
log(
|
108
|
+
log(DEBUG, "DriverServicer.PullTaskRes callback: delete TaskIns/TaskRes")
|
109
109
|
|
110
110
|
if context.is_active():
|
111
111
|
return
|
flwr/simulation/__init__.py
CHANGED
@@ -17,6 +17,8 @@
|
|
17
17
|
|
18
18
|
import importlib
|
19
19
|
|
20
|
+
from flwr.simulation.run_simulation import run_simulation
|
21
|
+
|
20
22
|
is_ray_installed = importlib.util.find_spec("ray") is not None
|
21
23
|
|
22
24
|
if is_ray_installed:
|
@@ -36,4 +38,5 @@ To install the necessary dependencies, install `flwr` with the `simulation` extr
|
|
36
38
|
|
37
39
|
__all__ = [
|
38
40
|
"start_simulation",
|
41
|
+
"run_simulation",
|
39
42
|
]
|
@@ -0,0 +1,177 @@
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Flower Simulation."""
|
16
|
+
|
17
|
+
import argparse
|
18
|
+
import asyncio
|
19
|
+
import json
|
20
|
+
import threading
|
21
|
+
import traceback
|
22
|
+
from logging import ERROR, INFO, WARNING
|
23
|
+
|
24
|
+
import grpc
|
25
|
+
|
26
|
+
from flwr.common import EventType, event, log
|
27
|
+
from flwr.common.exit_handlers import register_exit_handlers
|
28
|
+
from flwr.server.driver.driver import Driver
|
29
|
+
from flwr.server.run_serverapp import run
|
30
|
+
from flwr.server.superlink.driver.driver_grpc import run_driver_api_grpc
|
31
|
+
from flwr.server.superlink.fleet import vce
|
32
|
+
from flwr.server.superlink.state import StateFactory
|
33
|
+
from flwr.simulation.ray_transport.utils import enable_tf_gpu_growth
|
34
|
+
|
35
|
+
|
36
|
+
def run_simulation() -> None:
    """Run Simulation Engine.

    Parse the ``flower-simulation`` command-line arguments and start the
    Driver API server. Then start, in a separate non-daemon thread, a
    SuperLink backed by the Simulation Engine (``vce.start_vce``).

    Next, connect a ``Driver`` to the Driver API and launch the ``ServerApp``.
    Once the ``ServerApp`` returns or fails, signal the Simulation Engine to
    stop through ``f_stop`` and join the SuperLink thread.

    Raises
    ------
    RuntimeError
        If creating the ``Driver`` or running the ``ServerApp`` raises. The
        original exception is chained as ``__cause__``.
    """
    args = _parse_args_run_simulation().parse_args()

    # Load JSON config
    backend_config_dict = json.loads(args.backend_config)

    # Enable GPU memory growth (relevant only for TF)
    if args.enable_tf_gpu_growth:
        log(INFO, "Enabling GPU growth for Tensorflow on the main thread.")
        enable_tf_gpu_growth()
        # Check that Backend config has also enabled using GPU growth
        use_tf = backend_config_dict.get("tensorflow", False)
        if not use_tf:
            log(WARNING, "Enabling GPU growth for your backend.")
            backend_config_dict["tensorflow"] = True

    # Convert back to JSON stream
    backend_config = json.dumps(backend_config_dict)

    # Initialize StateFactory
    state_factory = StateFactory(":flwr-in-memory-state:")

    # Start Driver API
    driver_server: grpc.Server = run_driver_api_grpc(
        address=args.driver_api_address,
        state_factory=state_factory,
        certificates=None,
    )

    # SuperLink with Simulation Engine
    f_stop = asyncio.Event()
    superlink_th = threading.Thread(
        target=vce.start_vce,
        kwargs={
            "num_supernodes": args.num_supernodes,
            "client_app_module_name": args.client_app,
            "backend_name": args.backend,
            "backend_config_json_stream": backend_config,
            "working_dir": args.dir,
            "state_factory": state_factory,
            "f_stop": f_stop,
        },
        daemon=False,
    )

    superlink_th.start()
    event(EventType.RUN_SUPERLINK_ENTER)

    # Bind `driver` before the `try` so that `del driver` in `finally` is safe
    # even when `Driver(...)` itself raises. Previously that `del` raised
    # UnboundLocalError in this case. The error masked the real failure and
    # skipped `f_stop.set()`, so the Simulation Engine was never told to stop.
    driver = None
    try:
        # Initialize Driver
        driver = Driver(
            driver_service_address=args.driver_api_address,
            root_certificates=None,
        )

        # Launch server app
        run(args.server_app, driver, args.dir)

    except Exception as ex:

        log(ERROR, "An exception occurred: %s", ex)
        log(ERROR, traceback.format_exc())
        raise RuntimeError(
            "An error was encountered by the Simulation Engine. Ending simulation."
        ) from ex

    finally:
        # Drop our reference to the Driver (it is None if construction failed).
        del driver

        # Trigger stop event
        f_stop.set()

        register_exit_handlers(
            grpc_servers=[driver_server],
            bckg_threads=[superlink_th],
            event_type=EventType.RUN_SUPERLINK_LEAVE,
        )
        superlink_th.join()
|
116
|
+
|
117
|
+
|
118
|
+
def _parse_args_run_simulation() -> argparse.ArgumentParser:
|
119
|
+
"""Parse flower-simulation command line arguments."""
|
120
|
+
parser = argparse.ArgumentParser(
|
121
|
+
description="Start a Flower simulation",
|
122
|
+
)
|
123
|
+
parser.add_argument(
|
124
|
+
"--client-app",
|
125
|
+
required=True,
|
126
|
+
help="For example: `client:app` or `project.package.module:wrapper.app`",
|
127
|
+
)
|
128
|
+
parser.add_argument(
|
129
|
+
"--server-app",
|
130
|
+
required=True,
|
131
|
+
help="For example: `server:app` or `project.package.module:wrapper.app`",
|
132
|
+
)
|
133
|
+
parser.add_argument(
|
134
|
+
"--driver-api-address",
|
135
|
+
default="0.0.0.0:9091",
|
136
|
+
type=str,
|
137
|
+
help="For example: `server:app` or `project.package.module:wrapper.app`",
|
138
|
+
)
|
139
|
+
parser.add_argument(
|
140
|
+
"--num-supernodes",
|
141
|
+
type=int,
|
142
|
+
required=True,
|
143
|
+
help="Number of simulated SuperNodes.",
|
144
|
+
)
|
145
|
+
parser.add_argument(
|
146
|
+
"--backend",
|
147
|
+
default="ray",
|
148
|
+
type=str,
|
149
|
+
help="Simulation backend that executes the ClientApp.",
|
150
|
+
)
|
151
|
+
parser.add_argument(
|
152
|
+
"--enable-tf-gpu-growth",
|
153
|
+
action="store_true",
|
154
|
+
help="Enables GPU growth on the main thread. This is desirable if you make "
|
155
|
+
"use of a TensorFlow model on your `ServerApp` while having your `ClientApp` "
|
156
|
+
"running on the same GPU. Without enabling this, you might encounter an "
|
157
|
+
"out-of-memory error becasue TensorFlow by default allocates all GPU memory."
|
158
|
+
"Read mor about how `tf.config.experimental.set_memory_growth()` works in "
|
159
|
+
"the TensorFlow documentation: https://www.tensorflow.org/api/stable.",
|
160
|
+
)
|
161
|
+
parser.add_argument(
|
162
|
+
"--backend-config",
|
163
|
+
type=str,
|
164
|
+
default='{"client_resources": {"num_cpus":2, "num_gpus":0.0}, "tensorflow": 0}',
|
165
|
+
help='A JSON formatted stream, e.g \'{"<keyA>":<value>, "<keyB>":<value>}\' to '
|
166
|
+
"configure a backend. Values supported in <value> are those included by "
|
167
|
+
"`flwr.common.typing.ConfigsRecordValues`. ",
|
168
|
+
)
|
169
|
+
parser.add_argument(
|
170
|
+
"--dir",
|
171
|
+
default="",
|
172
|
+
help="Add specified directory to the PYTHONPATH and load"
|
173
|
+
"ClientApp and ServerApp from there."
|
174
|
+
" Default: current working directory.",
|
175
|
+
)
|
176
|
+
|
177
|
+
return parser
|
{flwr_nightly-1.8.0.dev20240228.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/RECORD
RENAMED
@@ -32,7 +32,7 @@ flwr/client/message_handler/task_handler.py,sha256=ZDJBKmrn2grRMNl1rU1iGs7FiMHL5
|
|
32
32
|
flwr/client/mod/__init__.py,sha256=w6r7n6fWIrrm4lEk36lh9f1Ix6LXgAzQUrgjmMspY98,961
|
33
33
|
flwr/client/mod/centraldp_mods.py,sha256=aHbzGjSbyRENuU5vzad_tkJ9UDb48uHEvUq-zgydBwo,4954
|
34
34
|
flwr/client/mod/secure_aggregation/__init__.py,sha256=AzCdezuzX2BfXUuxVRwXdv8-zUIXoU-Bf6u4LRhzvg8,796
|
35
|
-
flwr/client/mod/secure_aggregation/secaggplus_mod.py,sha256=
|
35
|
+
flwr/client/mod/secure_aggregation/secaggplus_mod.py,sha256=Zc5b2C58SYeyQVXIbLRESOIq4rMUOuzMHNAjdSAPt6I,19434
|
36
36
|
flwr/client/mod/utils.py,sha256=lvETHcCYsSWz7h8I772hCV_kZspxqlMqzriMZ-SxmKc,1226
|
37
37
|
flwr/client/node_state.py,sha256=KTTs_l4I0jBM7IsSsbAGjhfL_yZC3QANbzyvyfZBRDM,1778
|
38
38
|
flwr/client/node_state_tests.py,sha256=gPwz0zf2iuDSa11jedkur_u3Xm7lokIDG5ALD2MCvSw,2195
|
@@ -68,7 +68,7 @@ flwr/common/secure_aggregation/crypto/shamir.py,sha256=yY35ZgHlB4YyGW_buG-1X-0M-
|
|
68
68
|
flwr/common/secure_aggregation/crypto/symmetric_encryption.py,sha256=-zDyQoTsHHQjR7o-92FNIikg1zM_Ke9yynaD5u2BXbQ,3546
|
69
69
|
flwr/common/secure_aggregation/ndarrays_arithmetic.py,sha256=KAHCEHGSTJ6mCgnC8dTpIx6URk11s5XxWTHI_7ToGIg,2979
|
70
70
|
flwr/common/secure_aggregation/quantization.py,sha256=appui7GGrkRPsupF59TkapeV4Na_CyPi73JtJ1pimdI,2310
|
71
|
-
flwr/common/secure_aggregation/secaggplus_constants.py,sha256=
|
71
|
+
flwr/common/secure_aggregation/secaggplus_constants.py,sha256=2dYMKqiO2ja8DewFeZUAGqM0xujNfyYHVwShRxeRaIM,2132
|
72
72
|
flwr/common/secure_aggregation/secaggplus_utils.py,sha256=PleDyDu7jHNAfbRoEaoQiOjxG6iMl9yA8rNKYTfnyFw,3155
|
73
73
|
flwr/common/serde.py,sha256=0tmfTcywJVLA7Hsu4nAjMn2dVMVbjYZqePJbPcDt01Y,20726
|
74
74
|
flwr/common/telemetry.py,sha256=JkFB6WBOskqAJfzSM-l6tQfRiSi2oiysClfg0-5T7NY,7782
|
@@ -79,6 +79,10 @@ flwr/proto/driver_pb2.py,sha256=JHIdjNPTgp6YHD-_lz5ZZFB0VIOR3_GmcaOTN4jndc4,3115
|
|
79
79
|
flwr/proto/driver_pb2.pyi,sha256=xwl2AqIWn0SwAlg-x5RUQeqr6DC48eywnqmD7gbaaFs,4670
|
80
80
|
flwr/proto/driver_pb2_grpc.py,sha256=qQBRdQUz4k2K4DVO7kSfWHx-62UJ85HaYKnKCr6JcU8,7304
|
81
81
|
flwr/proto/driver_pb2_grpc.pyi,sha256=NpOM5eCrIPcuWdYrZAayQSDvvFp6cDCVflabhmuvMfo,2022
|
82
|
+
flwr/proto/error_pb2.py,sha256=LarjKL90LbwkXKlhzNrDssgl4DXcvIPve8NVCXHpsKA,1084
|
83
|
+
flwr/proto/error_pb2.pyi,sha256=ZNH4HhJTU_KfMXlyCeg8FwU-fcUYxTqEmoJPtWtHikc,734
|
84
|
+
flwr/proto/error_pb2_grpc.py,sha256=1oboBPFxaTEXt9Aw7EAj8gXHDCNMhZD2VXqocC9l_gk,159
|
85
|
+
flwr/proto/error_pb2_grpc.pyi,sha256=ff2TSiLVnG6IVQcTGzb2DIH3XRSoAvAo_RMcvbMFyc0,76
|
82
86
|
flwr/proto/fleet_pb2.py,sha256=8rKQHu6Oa9ki_NG6kRNGtfPPYZp5kKBZhPW696_kn84,3852
|
83
87
|
flwr/proto/fleet_pb2.pyi,sha256=QXYs9M7_dABghdCMfk5Rjf4w0LsZGDeQ1ojH00XaQME,6182
|
84
88
|
flwr/proto/fleet_pb2_grpc.py,sha256=hF1uPaioZzQMRCP9yPlv9LC0mi_DTuhn-IkQJzWIPCs,7505
|
@@ -91,8 +95,8 @@ flwr/proto/recordset_pb2.py,sha256=un8L0kvBcgFXQIiQweOseeIJBjlOozUvQY9uTQ42Dqo,6
|
|
91
95
|
flwr/proto/recordset_pb2.pyi,sha256=NPzCJWAj1xLWzeZ_xZ6uaObQjQfWGnnqlLtn4J-SoFY,14161
|
92
96
|
flwr/proto/recordset_pb2_grpc.py,sha256=1oboBPFxaTEXt9Aw7EAj8gXHDCNMhZD2VXqocC9l_gk,159
|
93
97
|
flwr/proto/recordset_pb2_grpc.pyi,sha256=ff2TSiLVnG6IVQcTGzb2DIH3XRSoAvAo_RMcvbMFyc0,76
|
94
|
-
flwr/proto/task_pb2.py,sha256
|
95
|
-
flwr/proto/task_pb2.pyi,sha256=
|
98
|
+
flwr/proto/task_pb2.py,sha256=-UX3TqskOIRbPu8U3YwgW9ul2k9ZN6MJGgbIOX3pTqg,2431
|
99
|
+
flwr/proto/task_pb2.pyi,sha256=IgXggFya0RpL64z2o2K_qLnZHyZ1mg_WzLxFwEKrpL0,4171
|
96
100
|
flwr/proto/task_pb2_grpc.py,sha256=1oboBPFxaTEXt9Aw7EAj8gXHDCNMhZD2VXqocC9l_gk,159
|
97
101
|
flwr/proto/task_pb2_grpc.pyi,sha256=ff2TSiLVnG6IVQcTGzb2DIH3XRSoAvAo_RMcvbMFyc0,76
|
98
102
|
flwr/proto/transport_pb2.py,sha256=cURzfpCgZvH7GEvBPLvTYijE3HvhK1MePjINk4xYArk,9781
|
@@ -118,10 +122,10 @@ flwr/server/run_serverapp.py,sha256=7LLE1cVQz0Rl-hZnY7DLXvFxWCdep8xgLgEVC-yffi0,
|
|
118
122
|
flwr/server/server.py,sha256=JLc2lg-qCchD-Jyg_hTBZQN3rXsnfAGx6qJAo0vqH2Y,17812
|
119
123
|
flwr/server/server_app.py,sha256=avNQ7AMMKsn09ly81C3UBgOfHhM_R29l4MrzlalGoj8,5892
|
120
124
|
flwr/server/server_config.py,sha256=yOHpkdyuhOm--Gy_4Vofvu6jCDxhyECEDpIy02beuCg,1018
|
121
|
-
flwr/server/strategy/__init__.py,sha256=
|
125
|
+
flwr/server/strategy/__init__.py,sha256=7eVZ3hQEg2BgA_usAeL6tsLp9T6XI1VYYoFy08Xn-ew,2836
|
122
126
|
flwr/server/strategy/aggregate.py,sha256=QyRIJtI5gnuY1NbgrcrOvkHxGIxBvApq7d9Y4xl-6W4,13468
|
123
127
|
flwr/server/strategy/bulyan.py,sha256=8GsSVJzRSoSWE2zQUKqC3Z795grdN9xpmc3MSGGXnzM,6532
|
124
|
-
flwr/server/strategy/dp_adaptive_clipping.py,sha256=
|
128
|
+
flwr/server/strategy/dp_adaptive_clipping.py,sha256=BVvX1LivyukvEtOZVZnpgVpzH8BBjvA3OmdGwFxgRuQ,16679
|
125
129
|
flwr/server/strategy/dp_fixed_clipping.py,sha256=v9YyX53jt2RatGnFxTK4ZMO_3SN7EdL9YCaaJtn9Fcc,12125
|
126
130
|
flwr/server/strategy/dpfedavg_adaptive.py,sha256=hLJkPQJl1bHjwrBNg3PSRFKf3no0hg5EHiFaWhHlWqw,4877
|
127
131
|
flwr/server/strategy/dpfedavg_fixed.py,sha256=G0yYxrPoM-MHQ889DYN3OeNiEeU0yQrjgAzcq0G653w,7219
|
@@ -145,7 +149,7 @@ flwr/server/strategy/strategy.py,sha256=g6VoIFogEviRub6G4QsKdIp6M_Ek6GhBhqcdNx5u
|
|
145
149
|
flwr/server/superlink/__init__.py,sha256=8tHYCfodUlRD8PCP9fHgvu8cz5N31A2QoRVL0jDJ15E,707
|
146
150
|
flwr/server/superlink/driver/__init__.py,sha256=STB1_DASVEg7Cu6L7VYxTzV7UMkgtBkFim09Z82Dh8I,712
|
147
151
|
flwr/server/superlink/driver/driver_grpc.py,sha256=1qSGDs1k_OVPWxp2ofxvQgtYXExrMeC3N_rNPVWH65M,1932
|
148
|
-
flwr/server/superlink/driver/driver_servicer.py,sha256=
|
152
|
+
flwr/server/superlink/driver/driver_servicer.py,sha256=p1rlocgqU2I4s5IwdU8rTZBkQ73yPmuWtK_aUlB7V84,4573
|
149
153
|
flwr/server/superlink/fleet/__init__.py,sha256=C6GCSD5eP5Of6_dIeSe1jx9HnV0icsvWyQ5EKAUHJRU,711
|
150
154
|
flwr/server/superlink/fleet/grpc_bidi/__init__.py,sha256=mgGJGjwT6VU7ovC1gdnnqttjyBPlNIcZnYRqx4K3IBQ,735
|
151
155
|
flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py,sha256=57b3UL5-baGdLwgCtB0dCUTTSbmmfMAXcXV5bjPZNWQ,5993
|
@@ -174,14 +178,15 @@ flwr/server/utils/tensorboard.py,sha256=k0G6bqsLx7wfYbH2KtXsDYcOCfyIeE12-hefXA7l
|
|
174
178
|
flwr/server/utils/validator.py,sha256=IJN2475yyD_i_9kg_SJ_JodIuZh58ufpWGUDQRAqu2s,4740
|
175
179
|
flwr/server/workflow/__init__.py,sha256=2YrKq5wUwge8Tm1xaAdf2P3l4LbM4olka6tO_0_Mu9A,787
|
176
180
|
flwr/server/workflow/default_workflows.py,sha256=DKSt14WY5m19ujwh6UDP4a31kRs-j6V_NZm5cXn01ZY,12705
|
177
|
-
flwr/simulation/__init__.py,sha256=
|
181
|
+
flwr/simulation/__init__.py,sha256=jdrJeTnLLj9Eyl8BRPXMewqkhTnxD7fvXDgyjfspy0Q,1359
|
178
182
|
flwr/simulation/app.py,sha256=WqJxdXTEuehwMW605p5NMmvBbKYx5tuqnV3Mp7jSWXM,13904
|
179
183
|
flwr/simulation/ray_transport/__init__.py,sha256=FsaAnzC4cw4DqoouBCix6496k29jACkfeIam55BvW9g,734
|
180
184
|
flwr/simulation/ray_transport/ray_actor.py,sha256=zRETW_xuCAOLRFaYnQ-q3IBSz0LIv_0RifGuhgWaYOg,19872
|
181
185
|
flwr/simulation/ray_transport/ray_client_proxy.py,sha256=DpmrBC87_sX3J4WrrwzyEDIjONUeliBZx9T-gZGuPmQ,6799
|
182
186
|
flwr/simulation/ray_transport/utils.py,sha256=TYdtfg1P9VfTdLMOJlifInGpxWHYs9UfUqIv2wfkRLA,2392
|
183
|
-
|
184
|
-
flwr_nightly-1.8.0.
|
185
|
-
flwr_nightly-1.8.0.
|
186
|
-
flwr_nightly-1.8.0.
|
187
|
-
flwr_nightly-1.8.0.
|
187
|
+
flwr/simulation/run_simulation.py,sha256=NYUFJ6cG5QtuwEl6f2IWJNMRo_jDmmA41DABqISpq-Q,5950
|
188
|
+
flwr_nightly-1.8.0.dev20240229.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
189
|
+
flwr_nightly-1.8.0.dev20240229.dist-info/METADATA,sha256=bnkE4rD-5EY00NUOdT-ZzoeuextoRXX3zZL0MLR6PbU,15184
|
190
|
+
flwr_nightly-1.8.0.dev20240229.dist-info/WHEEL,sha256=FMvqSimYX_P7y0a7UY-_Mc83r5zkBZsCYPm7Lr0Bsq4,88
|
191
|
+
flwr_nightly-1.8.0.dev20240229.dist-info/entry_points.txt,sha256=qz6t0YdMrV_PLbEarJ6ITIWSIRrTbG_jossZAYfXBZQ,311
|
192
|
+
flwr_nightly-1.8.0.dev20240229.dist-info/RECORD,,
|
@@ -3,6 +3,7 @@ flower-client-app=flwr.client:run_client_app
|
|
3
3
|
flower-driver-api=flwr.server:run_driver_api
|
4
4
|
flower-fleet-api=flwr.server:run_fleet_api
|
5
5
|
flower-server-app=flwr.server:run_server_app
|
6
|
+
flower-simulation=flwr.simulation:run_simulation
|
6
7
|
flower-superlink=flwr.server:run_superlink
|
7
8
|
flwr=flwr.cli.app:app
|
8
9
|
|
{flwr_nightly-1.8.0.dev20240228.dist-info → flwr_nightly-1.8.0.dev20240229.dist-info}/LICENSE
RENAMED
File without changes
|
File without changes
|