flwr 1.21.0__py3-none-any.whl → 1.23.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/app.py +17 -1
- flwr/cli/auth_plugin/__init__.py +15 -6
- flwr/cli/auth_plugin/auth_plugin.py +95 -0
- flwr/cli/auth_plugin/noop_auth_plugin.py +58 -0
- flwr/cli/auth_plugin/oidc_cli_plugin.py +16 -25
- flwr/cli/build.py +118 -47
- flwr/cli/{cli_user_auth_interceptor.py → cli_account_auth_interceptor.py} +6 -5
- flwr/cli/log.py +2 -2
- flwr/cli/login/login.py +34 -23
- flwr/cli/ls.py +13 -9
- flwr/cli/new/new.py +196 -42
- flwr/cli/new/templates/app/README.flowertune.md.tpl +1 -1
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +64 -47
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +68 -30
- flwr/cli/new/templates/app/code/client.jax.py.tpl +63 -42
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +80 -51
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +36 -13
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +71 -46
- flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +55 -0
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +75 -30
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +69 -44
- flwr/cli/new/templates/app/code/client.xgboost.py.tpl +110 -0
- flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +56 -90
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +1 -23
- flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +37 -58
- flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +39 -44
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -14
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +27 -29
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -19
- flwr/cli/new/templates/app/code/server.jax.py.tpl +27 -14
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +29 -19
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +30 -17
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +36 -26
- flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +31 -0
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +29 -21
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +28 -19
- flwr/cli/new/templates/app/code/server.xgboost.py.tpl +56 -0
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -20
- flwr/cli/new/templates/app/code/task.jax.py.tpl +1 -1
- flwr/cli/new/templates/app/code/task.numpy.py.tpl +1 -1
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +14 -27
- flwr/cli/new/templates/app/code/{task.pytorch_msg_api.py.tpl → task.pytorch_legacy_api.py.tpl} +27 -14
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +1 -2
- flwr/cli/new/templates/app/code/task.xgboost.py.tpl +67 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
- flwr/cli/new/templates/app/{pyproject.pytorch_msg_api.toml.tpl → pyproject.pytorch_legacy_api.toml.tpl} +3 -3
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +61 -0
- flwr/cli/pull.py +100 -0
- flwr/cli/run/run.py +11 -7
- flwr/cli/stop.py +2 -2
- flwr/cli/supernode/__init__.py +25 -0
- flwr/cli/supernode/ls.py +260 -0
- flwr/cli/supernode/register.py +185 -0
- flwr/cli/supernode/unregister.py +138 -0
- flwr/cli/utils.py +109 -69
- flwr/client/__init__.py +2 -1
- flwr/client/grpc_adapter_client/connection.py +6 -8
- flwr/client/grpc_rere_client/connection.py +59 -31
- flwr/client/grpc_rere_client/grpc_adapter.py +28 -12
- flwr/client/grpc_rere_client/{client_interceptor.py → node_auth_client_interceptor.py} +3 -6
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +7 -5
- flwr/client/rest_client/connection.py +82 -37
- flwr/clientapp/__init__.py +1 -2
- flwr/clientapp/mod/__init__.py +4 -1
- flwr/clientapp/mod/centraldp_mods.py +156 -40
- flwr/clientapp/mod/localdp_mod.py +169 -0
- flwr/clientapp/typing.py +22 -0
- flwr/{client/clientapp → clientapp}/utils.py +1 -1
- flwr/common/constant.py +56 -13
- flwr/common/exit/exit_code.py +24 -10
- flwr/common/inflatable_utils.py +10 -10
- flwr/common/record/array.py +3 -3
- flwr/common/record/arrayrecord.py +10 -1
- flwr/common/record/typeddict.py +12 -0
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -89
- flwr/common/serde.py +4 -2
- flwr/common/typing.py +7 -6
- flwr/compat/client/app.py +1 -1
- flwr/compat/client/grpc_client/connection.py +2 -2
- flwr/proto/control_pb2.py +48 -31
- flwr/proto/control_pb2.pyi +95 -5
- flwr/proto/control_pb2_grpc.py +136 -0
- flwr/proto/control_pb2_grpc.pyi +52 -0
- flwr/proto/fab_pb2.py +11 -7
- flwr/proto/fab_pb2.pyi +21 -1
- flwr/proto/fleet_pb2.py +31 -23
- flwr/proto/fleet_pb2.pyi +63 -23
- flwr/proto/fleet_pb2_grpc.py +98 -28
- flwr/proto/fleet_pb2_grpc.pyi +45 -13
- flwr/proto/node_pb2.py +3 -1
- flwr/proto/node_pb2.pyi +48 -0
- flwr/server/app.py +152 -114
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +17 -7
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +132 -38
- flwr/server/superlink/fleet/grpc_rere/{server_interceptor.py → node_auth_server_interceptor.py} +27 -51
- flwr/server/superlink/fleet/message_handler/message_handler.py +67 -22
- flwr/server/superlink/fleet/rest_rere/rest_api.py +52 -31
- flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -1
- flwr/server/superlink/fleet/vce/vce_api.py +18 -5
- flwr/server/superlink/linkstate/in_memory_linkstate.py +167 -73
- flwr/server/superlink/linkstate/linkstate.py +107 -24
- flwr/server/superlink/linkstate/linkstate_factory.py +2 -1
- flwr/server/superlink/linkstate/sqlite_linkstate.py +306 -255
- flwr/server/superlink/linkstate/utils.py +3 -54
- flwr/server/superlink/serverappio/serverappio_servicer.py +2 -2
- flwr/server/superlink/simulation/simulationio_servicer.py +1 -1
- flwr/server/utils/validator.py +2 -3
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +4 -2
- flwr/serverapp/strategy/__init__.py +26 -0
- flwr/serverapp/strategy/bulyan.py +238 -0
- flwr/serverapp/strategy/dp_adaptive_clipping.py +335 -0
- flwr/serverapp/strategy/dp_fixed_clipping.py +71 -49
- flwr/serverapp/strategy/fedadagrad.py +0 -3
- flwr/serverapp/strategy/fedadam.py +0 -3
- flwr/serverapp/strategy/fedavg.py +89 -64
- flwr/serverapp/strategy/fedavgm.py +198 -0
- flwr/serverapp/strategy/fedmedian.py +105 -0
- flwr/serverapp/strategy/fedprox.py +174 -0
- flwr/serverapp/strategy/fedtrimmedavg.py +176 -0
- flwr/serverapp/strategy/fedxgb_bagging.py +117 -0
- flwr/serverapp/strategy/fedxgb_cyclic.py +220 -0
- flwr/serverapp/strategy/fedyogi.py +0 -3
- flwr/serverapp/strategy/krum.py +112 -0
- flwr/serverapp/strategy/multikrum.py +247 -0
- flwr/serverapp/strategy/qfedavg.py +252 -0
- flwr/serverapp/strategy/strategy_utils.py +48 -0
- flwr/simulation/app.py +1 -1
- flwr/simulation/ray_transport/ray_actor.py +1 -1
- flwr/simulation/ray_transport/ray_client_proxy.py +1 -1
- flwr/simulation/run_simulation.py +28 -32
- flwr/supercore/cli/flower_superexec.py +26 -1
- flwr/supercore/constant.py +41 -0
- flwr/supercore/object_store/in_memory_object_store.py +0 -4
- flwr/supercore/object_store/object_store_factory.py +26 -6
- flwr/supercore/object_store/sqlite_object_store.py +252 -0
- flwr/{client/clientapp → supercore/primitives}/__init__.py +1 -1
- flwr/supercore/primitives/asymmetric.py +117 -0
- flwr/supercore/primitives/asymmetric_ed25519.py +165 -0
- flwr/supercore/sqlite_mixin.py +156 -0
- flwr/supercore/superexec/plugin/exec_plugin.py +11 -1
- flwr/supercore/superexec/run_superexec.py +16 -2
- flwr/supercore/utils.py +20 -0
- flwr/superlink/artifact_provider/__init__.py +22 -0
- flwr/superlink/artifact_provider/artifact_provider.py +37 -0
- flwr/{common → superlink}/auth_plugin/__init__.py +6 -6
- flwr/superlink/auth_plugin/auth_plugin.py +91 -0
- flwr/superlink/auth_plugin/noop_auth_plugin.py +87 -0
- flwr/superlink/servicer/control/{control_user_auth_interceptor.py → control_account_auth_interceptor.py} +19 -19
- flwr/superlink/servicer/control/control_event_log_interceptor.py +1 -1
- flwr/superlink/servicer/control/control_grpc.py +16 -11
- flwr/superlink/servicer/control/control_servicer.py +207 -58
- flwr/supernode/cli/flower_supernode.py +19 -26
- flwr/supernode/runtime/run_clientapp.py +2 -2
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +1 -1
- flwr/supernode/start_client_internal.py +17 -9
- {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/METADATA +6 -16
- {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/RECORD +170 -140
- flwr/cli/new/templates/app/code/client.pytorch_msg_api.py.tpl +0 -80
- flwr/cli/new/templates/app/code/server.pytorch_msg_api.py.tpl +0 -41
- flwr/common/auth_plugin/auth_plugin.py +0 -149
- flwr/serverapp/dp_fixed_clipping.py +0 -352
- flwr/serverapp/strategy/strategy_utils_tests.py +0 -304
- /flwr/cli/new/templates/app/code/{__init__.pytorch_msg_api.py.tpl → __init__.pytorch_legacy_api.py.tpl} +0 -0
- /flwr/{client → clientapp}/client_app.py +0 -0
- {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/WHEEL +0 -0
- {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/entry_points.txt +0 -0
|
@@ -33,6 +33,7 @@ from flwr.common.typing import RunStatus
|
|
|
33
33
|
# pylint: disable=E0611
|
|
34
34
|
from flwr.proto.message_pb2 import Context as ProtoContext
|
|
35
35
|
from flwr.proto.recorddict_pb2 import ConfigRecord as ProtoConfigRecord
|
|
36
|
+
from flwr.supercore.utils import int64_to_uint64, uint64_to_int64
|
|
36
37
|
|
|
37
38
|
# pylint: enable=E0611
|
|
38
39
|
VALID_RUN_STATUS_TRANSITIONS = {
|
|
@@ -76,58 +77,6 @@ def generate_rand_int_from_bytes(
|
|
|
76
77
|
return num
|
|
77
78
|
|
|
78
79
|
|
|
79
|
-
def convert_uint64_to_sint64(u: int) -> int:
|
|
80
|
-
"""Convert a uint64 value to a sint64 value with the same bit sequence.
|
|
81
|
-
|
|
82
|
-
Parameters
|
|
83
|
-
----------
|
|
84
|
-
u : int
|
|
85
|
-
The unsigned 64-bit integer to convert.
|
|
86
|
-
|
|
87
|
-
Returns
|
|
88
|
-
-------
|
|
89
|
-
int
|
|
90
|
-
The signed 64-bit integer equivalent.
|
|
91
|
-
|
|
92
|
-
The signed 64-bit integer will have the same bit pattern as the
|
|
93
|
-
unsigned 64-bit integer but may have a different decimal value.
|
|
94
|
-
|
|
95
|
-
For numbers within the range [0, `sint64` max value], the decimal
|
|
96
|
-
value remains the same. However, for numbers greater than the `sint64`
|
|
97
|
-
max value, the decimal value will differ due to the wraparound caused
|
|
98
|
-
by the sign bit.
|
|
99
|
-
"""
|
|
100
|
-
if u >= (1 << 63):
|
|
101
|
-
return u - (1 << 64)
|
|
102
|
-
return u
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
def convert_sint64_to_uint64(s: int) -> int:
|
|
106
|
-
"""Convert a sint64 value to a uint64 value with the same bit sequence.
|
|
107
|
-
|
|
108
|
-
Parameters
|
|
109
|
-
----------
|
|
110
|
-
s : int
|
|
111
|
-
The signed 64-bit integer to convert.
|
|
112
|
-
|
|
113
|
-
Returns
|
|
114
|
-
-------
|
|
115
|
-
int
|
|
116
|
-
The unsigned 64-bit integer equivalent.
|
|
117
|
-
|
|
118
|
-
The unsigned 64-bit integer will have the same bit pattern as the
|
|
119
|
-
signed 64-bit integer but may have a different decimal value.
|
|
120
|
-
|
|
121
|
-
For negative `sint64` values, the conversion adds 2^64 to the
|
|
122
|
-
signed value to obtain the equivalent `uint64` value. For non-negative
|
|
123
|
-
`sint64` values, the decimal value remains unchanged in the `uint64`
|
|
124
|
-
representation.
|
|
125
|
-
"""
|
|
126
|
-
if s < 0:
|
|
127
|
-
return s + (1 << 64)
|
|
128
|
-
return s
|
|
129
|
-
|
|
130
|
-
|
|
131
80
|
def convert_uint64_values_in_dict_to_sint64(
|
|
132
81
|
data_dict: dict[str, int], keys: list[str]
|
|
133
82
|
) -> None:
|
|
@@ -142,7 +91,7 @@ def convert_uint64_values_in_dict_to_sint64(
|
|
|
142
91
|
"""
|
|
143
92
|
for key in keys:
|
|
144
93
|
if key in data_dict:
|
|
145
|
-
data_dict[key] =
|
|
94
|
+
data_dict[key] = uint64_to_int64(data_dict[key])
|
|
146
95
|
|
|
147
96
|
|
|
148
97
|
def convert_sint64_values_in_dict_to_uint64(
|
|
@@ -159,7 +108,7 @@ def convert_sint64_values_in_dict_to_uint64(
|
|
|
159
108
|
"""
|
|
160
109
|
for key in keys:
|
|
161
110
|
if key in data_dict:
|
|
162
|
-
data_dict[key] =
|
|
111
|
+
data_dict[key] = int64_to_uint64(data_dict[key])
|
|
163
112
|
|
|
164
113
|
|
|
165
114
|
def context_to_bytes(context: Context) -> bytes:
|
|
@@ -316,7 +316,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
316
316
|
|
|
317
317
|
ffs: Ffs = self.ffs_factory.ffs()
|
|
318
318
|
if result := ffs.get(request.hash_str):
|
|
319
|
-
fab = Fab(request.hash_str, result[0])
|
|
319
|
+
fab = Fab(request.hash_str, result[0], result[1])
|
|
320
320
|
return GetFabResponse(fab=fab_to_proto(fab))
|
|
321
321
|
|
|
322
322
|
raise ValueError(f"Found no FAB with hash: {request.hash_str}")
|
|
@@ -343,7 +343,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
|
343
343
|
fab = None
|
|
344
344
|
if run and run.fab_hash:
|
|
345
345
|
if result := ffs.get(run.fab_hash):
|
|
346
|
-
fab = Fab(run.fab_hash, result[0])
|
|
346
|
+
fab = Fab(run.fab_hash, result[0], result[1])
|
|
347
347
|
if run and fab and serverapp_ctxt:
|
|
348
348
|
# Update run status to STARTING
|
|
349
349
|
if state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")):
|
|
@@ -150,7 +150,7 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
|
|
|
150
150
|
fab = None
|
|
151
151
|
if run and run.fab_hash:
|
|
152
152
|
if result := ffs.get(run.fab_hash):
|
|
153
|
-
fab = Fab(run.fab_hash, result[0])
|
|
153
|
+
fab = Fab(run.fab_hash, result[0], result[1])
|
|
154
154
|
if run and fab and serverapp_ctxt:
|
|
155
155
|
# Update run status to STARTING
|
|
156
156
|
if state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")):
|
flwr/server/utils/validator.py
CHANGED
|
@@ -15,10 +15,9 @@
|
|
|
15
15
|
"""Validators."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
import time
|
|
19
|
-
|
|
20
18
|
from flwr.common import Message
|
|
21
19
|
from flwr.common.constant import SUPERLINK_NODE_ID
|
|
20
|
+
from flwr.common.date import now
|
|
22
21
|
|
|
23
22
|
|
|
24
23
|
# pylint: disable-next=too-many-branches
|
|
@@ -44,7 +43,7 @@ def validate_message(message: Message, is_reply_message: bool) -> list[str]:
|
|
|
44
43
|
validation_errors.append("`metadata.ttl` must be higher than zero")
|
|
45
44
|
|
|
46
45
|
# Verify TTL and created_at time
|
|
47
|
-
current_time =
|
|
46
|
+
current_time = now().timestamp()
|
|
48
47
|
if metadata.created_at + metadata.ttl <= current_time:
|
|
49
48
|
validation_errors.append("Message TTL has expired")
|
|
50
49
|
|
|
@@ -35,8 +35,6 @@ from flwr.common import (
|
|
|
35
35
|
)
|
|
36
36
|
from flwr.common.secure_aggregation.crypto.shamir import combine_shares
|
|
37
37
|
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
|
38
|
-
bytes_to_private_key,
|
|
39
|
-
bytes_to_public_key,
|
|
40
38
|
generate_shared_key,
|
|
41
39
|
)
|
|
42
40
|
from flwr.common.secure_aggregation.ndarrays_arithmetic import (
|
|
@@ -56,6 +54,10 @@ from flwr.common.secure_aggregation.secaggplus_utils import pseudo_rand_gen
|
|
|
56
54
|
from flwr.server.client_proxy import ClientProxy
|
|
57
55
|
from flwr.server.compat.legacy_context import LegacyContext
|
|
58
56
|
from flwr.server.grid import Grid
|
|
57
|
+
from flwr.supercore.primitives.asymmetric import (
|
|
58
|
+
bytes_to_private_key,
|
|
59
|
+
bytes_to_public_key,
|
|
60
|
+
)
|
|
59
61
|
|
|
60
62
|
from ..constant import MAIN_CONFIGS_RECORD, MAIN_PARAMS_RECORD
|
|
61
63
|
from ..constant import Key as WorkflowKey
|
|
@@ -15,6 +15,11 @@
|
|
|
15
15
|
"""ServerApp strategies."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
from .bulyan import Bulyan
|
|
19
|
+
from .dp_adaptive_clipping import (
|
|
20
|
+
DifferentialPrivacyClientSideAdaptiveClipping,
|
|
21
|
+
DifferentialPrivacyServerSideAdaptiveClipping,
|
|
22
|
+
)
|
|
18
23
|
from .dp_fixed_clipping import (
|
|
19
24
|
DifferentialPrivacyClientSideFixedClipping,
|
|
20
25
|
DifferentialPrivacyServerSideFixedClipping,
|
|
@@ -22,17 +27,38 @@ from .dp_fixed_clipping import (
|
|
|
22
27
|
from .fedadagrad import FedAdagrad
|
|
23
28
|
from .fedadam import FedAdam
|
|
24
29
|
from .fedavg import FedAvg
|
|
30
|
+
from .fedavgm import FedAvgM
|
|
31
|
+
from .fedmedian import FedMedian
|
|
32
|
+
from .fedprox import FedProx
|
|
33
|
+
from .fedtrimmedavg import FedTrimmedAvg
|
|
34
|
+
from .fedxgb_bagging import FedXgbBagging
|
|
35
|
+
from .fedxgb_cyclic import FedXgbCyclic
|
|
25
36
|
from .fedyogi import FedYogi
|
|
37
|
+
from .krum import Krum
|
|
38
|
+
from .multikrum import MultiKrum
|
|
39
|
+
from .qfedavg import QFedAvg
|
|
26
40
|
from .result import Result
|
|
27
41
|
from .strategy import Strategy
|
|
28
42
|
|
|
29
43
|
__all__ = [
|
|
44
|
+
"Bulyan",
|
|
45
|
+
"DifferentialPrivacyClientSideAdaptiveClipping",
|
|
30
46
|
"DifferentialPrivacyClientSideFixedClipping",
|
|
47
|
+
"DifferentialPrivacyServerSideAdaptiveClipping",
|
|
31
48
|
"DifferentialPrivacyServerSideFixedClipping",
|
|
32
49
|
"FedAdagrad",
|
|
33
50
|
"FedAdam",
|
|
34
51
|
"FedAvg",
|
|
52
|
+
"FedAvgM",
|
|
53
|
+
"FedMedian",
|
|
54
|
+
"FedProx",
|
|
55
|
+
"FedTrimmedAvg",
|
|
56
|
+
"FedXgbBagging",
|
|
57
|
+
"FedXgbCyclic",
|
|
35
58
|
"FedYogi",
|
|
59
|
+
"Krum",
|
|
60
|
+
"MultiKrum",
|
|
61
|
+
"QFedAvg",
|
|
36
62
|
"Result",
|
|
37
63
|
"Strategy",
|
|
38
64
|
]
|
|
@@ -0,0 +1,238 @@
|
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Bulyan [El Mhamdi et al., 2018] strategy.
|
|
16
|
+
|
|
17
|
+
Paper: arxiv.org/abs/1802.07927
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
from collections import OrderedDict
|
|
22
|
+
from collections.abc import Iterable
|
|
23
|
+
from logging import INFO, WARN
|
|
24
|
+
from typing import Callable, Optional, cast
|
|
25
|
+
|
|
26
|
+
import numpy as np
|
|
27
|
+
|
|
28
|
+
from flwr.common import (
|
|
29
|
+
Array,
|
|
30
|
+
ArrayRecord,
|
|
31
|
+
Message,
|
|
32
|
+
MetricRecord,
|
|
33
|
+
NDArrays,
|
|
34
|
+
RecordDict,
|
|
35
|
+
log,
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
from .fedavg import FedAvg
|
|
39
|
+
from .multikrum import select_multikrum
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
# pylint: disable=too-many-instance-attributes
|
|
43
|
+
class Bulyan(FedAvg):
    """Bulyan strategy.

    Implementation based on https://arxiv.org/abs/1802.07927.

    Training replies are aggregated in two steps. With ``n`` valid replies and
    ``f = num_malicious_nodes``, a Byzantine-resilient selection rule first
    picks ``theta = n - 2f`` replies. The result is then, per coordinate, the
    mean of the ``beta = theta - 2f`` selected values closest to the
    coordinate-wise median of the selected replies.

    Parameters
    ----------
    fraction_train : float (default: 1.0)
        Fraction of nodes used during training. In case `min_train_nodes`
        is larger than `fraction_train * total_connected_nodes`, `min_train_nodes`
        will still be sampled.
    fraction_evaluate : float (default: 1.0)
        Fraction of nodes used during validation. In case `min_evaluate_nodes`
        is larger than `fraction_evaluate * total_connected_nodes`,
        `min_evaluate_nodes` will still be sampled.
    min_train_nodes : int (default: 2)
        Minimum number of nodes used during training.
    min_evaluate_nodes : int (default: 2)
        Minimum number of nodes used during validation.
    min_available_nodes : int (default: 2)
        Minimum number of total nodes in the system.
    num_malicious_nodes : int (default: 0)
        Number of malicious nodes in the system. Must be non-negative.
    weighted_by_key : str (default: "num-examples")
        The key within each MetricRecord whose value is used as the weight when
        computing weighted averages for MetricRecords.
    arrayrecord_key : str (default: "arrays")
        Key used to store the ArrayRecord when constructing Messages.
    configrecord_key : str (default: "config")
        Key used to store the ConfigRecord when constructing Messages.
    train_metrics_aggr_fn : Optional[callable] (default: None)
        Function with signature (list[RecordDict], str) -> MetricRecord,
        used to aggregate MetricRecords from training round replies.
        If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
        average using the provided weight factor key.
    evaluate_metrics_aggr_fn : Optional[callable] (default: None)
        Function with signature (list[RecordDict], str) -> MetricRecord,
        used to aggregate MetricRecords from evaluation round replies.
        If `None`, defaults to `aggregate_metricrecords`, which performs a weighted
        average using the provided weight factor key.
    selection_rule : Optional[Callable] (default: None)
        Function with signature (list[RecordDict], int, int) -> list[RecordDict].
        The inputs are:
        - a list of contents from reply messages,
        - the assumed number of malicious nodes (`num_malicious_nodes`),
        - the number of nodes to select (`num_nodes_to_select`).

        The function should implement a Byzantine-resilient selection rule that
        serves as the first step of Bulyan. If None, defaults to `select_multikrum`,
        which selects nodes according to the Multi-Krum algorithm.

    Raises
    ------
    ValueError
        If `num_malicious_nodes` is negative.
    """

    # pylint: disable=too-many-arguments,too-many-positional-arguments
    def __init__(
        self,
        fraction_train: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_train_nodes: int = 2,
        min_evaluate_nodes: int = 2,
        min_available_nodes: int = 2,
        num_malicious_nodes: int = 0,
        weighted_by_key: str = "num-examples",
        arrayrecord_key: str = "arrays",
        configrecord_key: str = "config",
        train_metrics_aggr_fn: Optional[
            Callable[[list[RecordDict], str], MetricRecord]
        ] = None,
        evaluate_metrics_aggr_fn: Optional[
            Callable[[list[RecordDict], str], MetricRecord]
        ] = None,
        selection_rule: Optional[
            Callable[[list[RecordDict], int, int], list[RecordDict]]
        ] = None,
    ) -> None:
        super().__init__(
            fraction_train=fraction_train,
            fraction_evaluate=fraction_evaluate,
            min_train_nodes=min_train_nodes,
            min_evaluate_nodes=min_evaluate_nodes,
            min_available_nodes=min_available_nodes,
            weighted_by_key=weighted_by_key,
            arrayrecord_key=arrayrecord_key,
            configrecord_key=configrecord_key,
            train_metrics_aggr_fn=train_metrics_aggr_fn,
            evaluate_metrics_aggr_fn=evaluate_metrics_aggr_fn,
        )
        # A negative count would lower the reply threshold and inflate both
        # theta and beta, so reject it up front instead of failing mid-round.
        if num_malicious_nodes < 0:
            raise ValueError(
                "`num_malicious_nodes` must be non-negative, "
                f"but got {num_malicious_nodes}."
            )
        self.num_malicious_nodes = num_malicious_nodes
        self.selection_rule = selection_rule or select_multikrum

    def summary(self) -> None:
        """Log summary configuration of the strategy."""
        # Callable objects and `functools.partial` instances have no
        # `__name__`; fall back to their repr rather than raising.
        rule_name = getattr(
            self.selection_rule, "__name__", repr(self.selection_rule)
        )
        log(INFO, "\t├──> Bulyan settings:")
        log(INFO, "\t│\t├── Number of malicious nodes: %d", self.num_malicious_nodes)
        log(INFO, "\t│\t└── Selection rule: %s", rule_name)
        super().summary()

    def aggregate_train(
        self,
        server_round: int,
        replies: Iterable[Message],
    ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
        """Aggregate ArrayRecords and MetricRecords in the received Messages.

        Returns ``(None, None)`` when fewer than ``4 * num_malicious_nodes + 3``
        valid replies were received; otherwise returns the Bulyan-aggregated
        ArrayRecord together with the MetricRecord aggregated over the
        selected replies.
        """
        valid_replies, _ = self._check_and_log_replies(replies, is_train=True)

        # Check if sufficient replies have been received
        min_replies = 4 * self.num_malicious_nodes + 3
        if len(valid_replies) < min_replies:
            log(
                WARN,
                "Insufficient replies, skipping Bulyan aggregation: "
                "Required at least %d (4*num_malicious_nodes + 3), but received %d.",
                min_replies,
                len(valid_replies),
            )
            return None, None

        reply_contents = [msg.content for msg in valid_replies]

        # Compute theta (number of replies to select) and beta (number of
        # values averaged per coordinate)
        theta = len(valid_replies) - 2 * self.num_malicious_nodes
        beta = theta - 2 * self.num_malicious_nodes

        # Byzantine-resilient selection rule
        selected_contents = self.selection_rule(
            reply_contents, self.num_malicious_nodes, theta
        )

        # Convert each ArrayRecord to a list of NDArray for easier computation.
        # The first ArrayRecord key of the first selected reply is used for all
        # selected replies.
        key = next(iter(selected_contents[0].array_records.keys()))
        array_keys = list(selected_contents[0][key].keys())
        selected_ndarrays = [
            cast(ArrayRecord, ctnt[key]).to_numpy_ndarrays(keep_input=False)
            for ctnt in selected_contents
        ]

        # Compute the coordinate-wise median, layer by layer
        median_ndarrays = [np.median(arr, axis=0) for arr in zip(*selected_ndarrays)]

        # Aggregate the beta closest weights element-wise
        aggregated_ndarrays = aggregate_n_closest_weights(
            median_ndarrays, selected_ndarrays, beta
        )

        # Convert to ArrayRecord, restoring the original array names
        arrays = ArrayRecord(
            OrderedDict(zip(array_keys, map(Array, aggregated_ndarrays)))
        )

        # Aggregate MetricRecords (only over the selected replies)
        metrics = self.train_metrics_aggr_fn(
            selected_contents,
            self.weighted_by_key,
        )
        return arrays, metrics
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def aggregate_n_closest_weights(
    ref_weights: NDArrays, weights_list: list[NDArrays], beta: int
) -> NDArrays:
    """Compute the element-wise mean of the `beta` closest weight arrays.

    For each element (i-th coordinate), the output is the average of the
    `beta` values, taken across the weight arrays, that are closest in
    absolute difference to the corresponding reference value.

    Parameters
    ----------
    ref_weights : NDArrays
        Reference weights used to compute distances.
    weights_list : list[NDArrays]
        List of weight arrays (e.g., from selected nodes).
    beta : int
        Number of closest weight arrays to include in the averaging.
        Must satisfy ``1 <= beta <= len(weights_list)``.

    Returns
    -------
    aggregated_weights : NDArrays
        Element-wise average of the `beta` closest weight arrays to the
        reference weights.

    Raises
    ------
    ValueError
        If `beta` is smaller than 1 or larger than ``len(weights_list)``.
    """
    # Without this check, `beta == 0` yields NaN arrays (mean of an empty
    # selection) and a negative `beta` silently averages an arbitrary subset,
    # because both the partition index and the slice accept negative values.
    num_models = len(weights_list)
    if not 1 <= beta <= num_models:
        raise ValueError(
            "`beta` must be between 1 and the number of weight lists "
            f"({num_models}), but got {beta}."
        )

    aggregated_weights = []
    for layer_id, ref_layer in enumerate(ref_weights):
        # Shape: (n_models, *layer_shape)
        layer_stack = np.stack([weights[layer_id] for weights in weights_list])

        # Compute absolute differences: shape (n_models, *layer_shape)
        diffs = np.abs(layer_stack - ref_layer)

        # Find indices of `beta` smallest per coordinate; the first `beta`
        # rows of the partition hold them (in no particular order)
        idx = np.argpartition(diffs, beta - 1, axis=0)[:beta]

        # Gather the closest weights
        closest = np.take_along_axis(layer_stack, idx, axis=0)

        # Average them
        aggregated_weights.append(np.mean(closest, axis=0))

    return aggregated_weights
|