flwr 1.22.0__py3-none-any.whl → 1.24.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/__init__.py +16 -5
- flwr/app/error.py +2 -2
- flwr/app/exception.py +3 -3
- flwr/cli/app.py +34 -1
- flwr/cli/app_cmd/__init__.py +23 -0
- flwr/cli/app_cmd/publish.py +285 -0
- flwr/cli/app_cmd/review.py +252 -0
- flwr/cli/auth_plugin/__init__.py +15 -6
- flwr/cli/auth_plugin/auth_plugin.py +94 -0
- flwr/cli/auth_plugin/noop_auth_plugin.py +101 -0
- flwr/cli/auth_plugin/oidc_cli_plugin.py +46 -32
- flwr/cli/build.py +166 -53
- flwr/cli/{cli_user_auth_interceptor.py → cli_account_auth_interceptor.py} +29 -11
- flwr/cli/config_utils.py +101 -13
- flwr/cli/federation/__init__.py +24 -0
- flwr/cli/federation/ls.py +140 -0
- flwr/cli/federation/show.py +317 -0
- flwr/cli/install.py +91 -13
- flwr/cli/log.py +54 -11
- flwr/cli/login/login.py +41 -27
- flwr/cli/ls.py +177 -133
- flwr/cli/new/new.py +175 -40
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
- flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +1 -1
- flwr/cli/pull.py +12 -7
- flwr/cli/run/run.py +82 -31
- flwr/cli/run_utils.py +130 -0
- flwr/cli/stop.py +27 -9
- flwr/cli/supernode/__init__.py +25 -0
- flwr/cli/supernode/ls.py +268 -0
- flwr/cli/supernode/register.py +190 -0
- flwr/cli/supernode/unregister.py +140 -0
- flwr/cli/utils.py +464 -81
- flwr/client/__init__.py +2 -1
- flwr/client/dpfedavg_numpy_client.py +4 -1
- flwr/client/grpc_adapter_client/connection.py +12 -15
- flwr/client/grpc_rere_client/connection.py +68 -41
- flwr/client/grpc_rere_client/grpc_adapter.py +34 -14
- flwr/client/grpc_rere_client/{client_interceptor.py → node_auth_client_interceptor.py} +5 -7
- flwr/client/message_handler/message_handler.py +2 -2
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +10 -8
- flwr/client/numpy_client.py +1 -1
- flwr/client/rest_client/connection.py +94 -51
- flwr/client/run_info_store.py +4 -5
- flwr/client/typing.py +1 -1
- flwr/clientapp/__init__.py +1 -2
- flwr/{client → clientapp}/client_app.py +9 -10
- flwr/clientapp/mod/centraldp_mods.py +16 -17
- flwr/clientapp/mod/localdp_mod.py +8 -9
- flwr/clientapp/typing.py +1 -1
- flwr/{client/clientapp → clientapp}/utils.py +4 -4
- flwr/common/address.py +1 -2
- flwr/common/args.py +3 -4
- flwr/common/config.py +13 -16
- flwr/common/constant.py +56 -13
- flwr/common/differential_privacy.py +3 -4
- flwr/common/event_log_plugin/event_log_plugin.py +3 -4
- flwr/common/exit/exit.py +15 -2
- flwr/common/exit/exit_code.py +39 -10
- flwr/common/exit/exit_handler.py +6 -2
- flwr/common/exit/signal_handler.py +5 -5
- flwr/common/grpc.py +6 -6
- flwr/common/inflatable_protobuf_utils.py +1 -1
- flwr/common/inflatable_utils.py +48 -31
- flwr/common/logger.py +19 -19
- flwr/common/message.py +4 -4
- flwr/common/object_ref.py +7 -7
- flwr/common/record/array.py +6 -6
- flwr/common/record/arrayrecord.py +18 -21
- flwr/common/record/configrecord.py +3 -3
- flwr/common/record/recorddict.py +5 -5
- flwr/common/record/typeddict.py +9 -2
- flwr/common/recorddict_compat.py +7 -10
- flwr/common/retry_invoker.py +20 -20
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -89
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +3 -3
- flwr/common/serde.py +9 -6
- flwr/common/serde_utils.py +2 -2
- flwr/common/telemetry.py +9 -5
- flwr/common/typing.py +59 -43
- flwr/compat/client/app.py +39 -38
- flwr/compat/client/grpc_client/connection.py +13 -13
- flwr/compat/server/app.py +5 -6
- flwr/proto/appio_pb2.py +13 -3
- flwr/proto/appio_pb2.pyi +134 -65
- flwr/proto/appio_pb2_grpc.py +20 -0
- flwr/proto/appio_pb2_grpc.pyi +27 -0
- flwr/proto/clientappio_pb2.py +17 -7
- flwr/proto/clientappio_pb2.pyi +15 -0
- flwr/proto/clientappio_pb2_grpc.py +206 -40
- flwr/proto/clientappio_pb2_grpc.pyi +168 -53
- flwr/proto/control_pb2.py +72 -40
- flwr/proto/control_pb2.pyi +319 -87
- flwr/proto/control_pb2_grpc.py +339 -28
- flwr/proto/control_pb2_grpc.pyi +209 -37
- flwr/proto/error_pb2.py +13 -3
- flwr/proto/error_pb2.pyi +24 -6
- flwr/proto/error_pb2_grpc.py +20 -0
- flwr/proto/error_pb2_grpc.pyi +27 -0
- flwr/proto/fab_pb2.py +24 -10
- flwr/proto/fab_pb2.pyi +68 -20
- flwr/proto/fab_pb2_grpc.py +20 -0
- flwr/proto/fab_pb2_grpc.pyi +27 -0
- flwr/proto/federation_pb2.py +38 -0
- flwr/proto/federation_pb2.pyi +56 -0
- flwr/proto/federation_pb2_grpc.py +24 -0
- flwr/proto/federation_pb2_grpc.pyi +31 -0
- flwr/proto/fleet_pb2.py +45 -27
- flwr/proto/fleet_pb2.pyi +186 -70
- flwr/proto/fleet_pb2_grpc.py +277 -66
- flwr/proto/fleet_pb2_grpc.pyi +201 -55
- flwr/proto/grpcadapter_pb2.py +14 -4
- flwr/proto/grpcadapter_pb2.pyi +38 -16
- flwr/proto/grpcadapter_pb2_grpc.py +35 -4
- flwr/proto/grpcadapter_pb2_grpc.pyi +38 -7
- flwr/proto/heartbeat_pb2.py +17 -7
- flwr/proto/heartbeat_pb2.pyi +51 -22
- flwr/proto/heartbeat_pb2_grpc.py +20 -0
- flwr/proto/heartbeat_pb2_grpc.pyi +27 -0
- flwr/proto/log_pb2.py +13 -3
- flwr/proto/log_pb2.pyi +34 -11
- flwr/proto/log_pb2_grpc.py +20 -0
- flwr/proto/log_pb2_grpc.pyi +27 -0
- flwr/proto/message_pb2.py +15 -5
- flwr/proto/message_pb2.pyi +154 -86
- flwr/proto/message_pb2_grpc.py +20 -0
- flwr/proto/message_pb2_grpc.pyi +27 -0
- flwr/proto/node_pb2.py +16 -4
- flwr/proto/node_pb2.pyi +77 -4
- flwr/proto/node_pb2_grpc.py +20 -0
- flwr/proto/node_pb2_grpc.pyi +27 -0
- flwr/proto/recorddict_pb2.py +13 -3
- flwr/proto/recorddict_pb2.pyi +184 -107
- flwr/proto/recorddict_pb2_grpc.py +20 -0
- flwr/proto/recorddict_pb2_grpc.pyi +27 -0
- flwr/proto/run_pb2.py +40 -31
- flwr/proto/run_pb2.pyi +149 -84
- flwr/proto/run_pb2_grpc.py +20 -0
- flwr/proto/run_pb2_grpc.pyi +27 -0
- flwr/proto/serverappio_pb2.py +13 -3
- flwr/proto/serverappio_pb2.pyi +32 -8
- flwr/proto/serverappio_pb2_grpc.py +246 -65
- flwr/proto/serverappio_pb2_grpc.pyi +221 -85
- flwr/proto/simulationio_pb2.py +16 -8
- flwr/proto/simulationio_pb2.pyi +15 -0
- flwr/proto/simulationio_pb2_grpc.py +162 -41
- flwr/proto/simulationio_pb2_grpc.pyi +149 -55
- flwr/proto/transport_pb2.py +20 -10
- flwr/proto/transport_pb2.pyi +249 -160
- flwr/proto/transport_pb2_grpc.py +35 -4
- flwr/proto/transport_pb2_grpc.pyi +38 -8
- flwr/server/app.py +173 -127
- flwr/server/client_manager.py +4 -5
- flwr/server/client_proxy.py +10 -11
- flwr/server/compat/app.py +4 -5
- flwr/server/compat/app_utils.py +2 -1
- flwr/server/compat/grid_client_proxy.py +10 -12
- flwr/server/compat/legacy_context.py +3 -4
- flwr/server/fleet_event_log_interceptor.py +2 -1
- flwr/server/grid/grid.py +2 -3
- flwr/server/grid/grpc_grid.py +10 -8
- flwr/server/grid/inmemory_grid.py +4 -4
- flwr/server/run_serverapp.py +2 -3
- flwr/server/server.py +34 -39
- flwr/server/server_app.py +7 -8
- flwr/server/server_config.py +1 -2
- flwr/server/serverapp/app.py +34 -28
- flwr/server/serverapp_components.py +4 -5
- flwr/server/strategy/aggregate.py +9 -8
- flwr/server/strategy/bulyan.py +13 -11
- flwr/server/strategy/dp_adaptive_clipping.py +16 -20
- flwr/server/strategy/dp_fixed_clipping.py +12 -17
- flwr/server/strategy/dpfedavg_adaptive.py +3 -4
- flwr/server/strategy/dpfedavg_fixed.py +6 -10
- flwr/server/strategy/fault_tolerant_fedavg.py +14 -13
- flwr/server/strategy/fedadagrad.py +18 -14
- flwr/server/strategy/fedadam.py +16 -14
- flwr/server/strategy/fedavg.py +16 -17
- flwr/server/strategy/fedavg_android.py +15 -15
- flwr/server/strategy/fedavgm.py +21 -18
- flwr/server/strategy/fedmedian.py +2 -3
- flwr/server/strategy/fedopt.py +11 -10
- flwr/server/strategy/fedprox.py +10 -9
- flwr/server/strategy/fedtrimmedavg.py +12 -11
- flwr/server/strategy/fedxgb_bagging.py +13 -11
- flwr/server/strategy/fedxgb_cyclic.py +6 -6
- flwr/server/strategy/fedxgb_nn_avg.py +4 -4
- flwr/server/strategy/fedyogi.py +16 -14
- flwr/server/strategy/krum.py +12 -11
- flwr/server/strategy/qfedavg.py +16 -15
- flwr/server/strategy/strategy.py +6 -9
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +19 -8
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -2
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +3 -4
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +10 -12
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +1 -3
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +136 -42
- flwr/server/superlink/fleet/grpc_rere/{server_interceptor.py → node_auth_server_interceptor.py} +28 -51
- flwr/server/superlink/fleet/message_handler/message_handler.py +100 -49
- flwr/server/superlink/fleet/rest_rere/rest_api.py +54 -33
- flwr/server/superlink/fleet/vce/backend/backend.py +2 -2
- flwr/server/superlink/fleet/vce/backend/raybackend.py +6 -6
- flwr/server/superlink/fleet/vce/vce_api.py +32 -13
- flwr/server/superlink/linkstate/in_memory_linkstate.py +266 -207
- flwr/server/superlink/linkstate/linkstate.py +161 -62
- flwr/server/superlink/linkstate/linkstate_factory.py +24 -6
- flwr/server/superlink/linkstate/sqlite_linkstate.py +698 -638
- flwr/server/superlink/linkstate/utils.py +9 -60
- flwr/server/superlink/serverappio/serverappio_grpc.py +1 -2
- flwr/server/superlink/serverappio/serverappio_servicer.py +28 -23
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -2
- flwr/server/superlink/simulation/simulationio_servicer.py +19 -14
- flwr/server/superlink/utils.py +4 -6
- flwr/server/typing.py +1 -1
- flwr/server/utils/tensorboard.py +15 -8
- flwr/server/utils/validator.py +2 -3
- flwr/server/workflow/default_workflows.py +5 -5
- flwr/server/workflow/secure_aggregation/secagg_workflow.py +2 -4
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +12 -10
- flwr/serverapp/strategy/bulyan.py +16 -15
- flwr/serverapp/strategy/dp_adaptive_clipping.py +12 -11
- flwr/serverapp/strategy/dp_fixed_clipping.py +11 -14
- flwr/serverapp/strategy/fedadagrad.py +10 -11
- flwr/serverapp/strategy/fedadam.py +10 -11
- flwr/serverapp/strategy/fedavg.py +9 -10
- flwr/serverapp/strategy/fedavgm.py +17 -16
- flwr/serverapp/strategy/fedmedian.py +2 -2
- flwr/serverapp/strategy/fedopt.py +10 -11
- flwr/serverapp/strategy/fedprox.py +7 -8
- flwr/serverapp/strategy/fedtrimmedavg.py +9 -9
- flwr/serverapp/strategy/fedxgb_bagging.py +3 -3
- flwr/serverapp/strategy/fedxgb_cyclic.py +9 -9
- flwr/serverapp/strategy/fedyogi.py +9 -11
- flwr/serverapp/strategy/krum.py +7 -7
- flwr/serverapp/strategy/multikrum.py +9 -9
- flwr/serverapp/strategy/qfedavg.py +17 -16
- flwr/serverapp/strategy/strategy.py +6 -9
- flwr/serverapp/strategy/strategy_utils.py +7 -8
- flwr/simulation/app.py +46 -42
- flwr/simulation/legacy_app.py +12 -12
- flwr/simulation/ray_transport/ray_actor.py +11 -12
- flwr/simulation/ray_transport/ray_client_proxy.py +12 -13
- flwr/simulation/run_simulation.py +44 -43
- flwr/simulation/simulationio_connection.py +4 -4
- flwr/supercore/cli/flower_superexec.py +3 -4
- flwr/supercore/constant.py +52 -0
- flwr/supercore/corestate/corestate.py +24 -3
- flwr/supercore/corestate/in_memory_corestate.py +138 -0
- flwr/supercore/corestate/sqlite_corestate.py +157 -0
- flwr/supercore/ffs/disk_ffs.py +1 -2
- flwr/supercore/ffs/ffs.py +1 -2
- flwr/supercore/ffs/ffs_factory.py +1 -2
- flwr/{common → supercore}/heartbeat.py +20 -25
- flwr/supercore/object_store/in_memory_object_store.py +1 -6
- flwr/supercore/object_store/object_store.py +1 -2
- flwr/supercore/object_store/object_store_factory.py +27 -8
- flwr/supercore/object_store/sqlite_object_store.py +253 -0
- flwr/{client/clientapp → supercore/primitives}/__init__.py +1 -1
- flwr/supercore/primitives/asymmetric.py +117 -0
- flwr/supercore/primitives/asymmetric_ed25519.py +175 -0
- flwr/supercore/sqlite_mixin.py +159 -0
- flwr/supercore/superexec/plugin/base_exec_plugin.py +1 -2
- flwr/supercore/superexec/plugin/exec_plugin.py +3 -3
- flwr/supercore/superexec/run_superexec.py +9 -13
- flwr/supercore/utils.py +20 -0
- flwr/superlink/artifact_provider/artifact_provider.py +1 -2
- flwr/{common → superlink}/auth_plugin/__init__.py +6 -6
- flwr/superlink/auth_plugin/auth_plugin.py +88 -0
- flwr/superlink/auth_plugin/noop_auth_plugin.py +84 -0
- flwr/superlink/federation/__init__.py +24 -0
- flwr/superlink/federation/federation_manager.py +64 -0
- flwr/superlink/federation/noop_federation_manager.py +71 -0
- flwr/superlink/servicer/control/{control_user_auth_interceptor.py → control_account_auth_interceptor.py} +41 -32
- flwr/superlink/servicer/control/control_event_log_interceptor.py +7 -7
- flwr/superlink/servicer/control/control_grpc.py +18 -17
- flwr/superlink/servicer/control/control_license_interceptor.py +3 -3
- flwr/superlink/servicer/control/control_servicer.py +239 -63
- flwr/supernode/cli/flower_supernode.py +74 -26
- flwr/supernode/nodestate/in_memory_nodestate.py +60 -49
- flwr/supernode/nodestate/nodestate.py +7 -8
- flwr/supernode/nodestate/nodestate_factory.py +7 -4
- flwr/supernode/runtime/run_clientapp.py +43 -24
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +40 -10
- flwr/supernode/start_client_internal.py +175 -51
- {flwr-1.22.0.dist-info → flwr-1.24.0.dist-info}/METADATA +8 -8
- flwr-1.24.0.dist-info/RECORD +454 -0
- flwr/common/auth_plugin/auth_plugin.py +0 -149
- flwr/supercore/object_store/utils.py +0 -43
- flwr-1.22.0.dist-info/RECORD +0 -428
- {flwr-1.22.0.dist-info → flwr-1.24.0.dist-info}/WHEEL +0 -0
- {flwr-1.22.0.dist-info → flwr-1.24.0.dist-info}/entry_points.txt +0 -0
flwr/client/__init__.py
CHANGED
|
@@ -15,10 +15,11 @@
|
|
|
15
15
|
"""Flower client."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
from flwr.clientapp.client_app import ClientApp
|
|
19
|
+
|
|
18
20
|
from ..compat.client.app import start_client as start_client # Deprecated
|
|
19
21
|
from ..compat.client.app import start_numpy_client as start_numpy_client # Deprecated
|
|
20
22
|
from .client import Client as Client
|
|
21
|
-
from .client_app import ClientApp as ClientApp
|
|
22
23
|
from .numpy_client import NumPyClient as NumPyClient
|
|
23
24
|
from .typing import ClientFn as ClientFn
|
|
24
25
|
from .typing import ClientFnExt as ClientFnExt
|
|
@@ -120,7 +120,10 @@ class DPFedAvgNumPyClient(NumPyClient):
|
|
|
120
120
|
updated_params, num_examples, metrics = self.client.fit(parameters, config)
|
|
121
121
|
|
|
122
122
|
# Update = updated model - original model
|
|
123
|
-
update = [
|
|
123
|
+
update = [
|
|
124
|
+
np.subtract(x, y)
|
|
125
|
+
for (x, y) in zip(updated_params, original_params, strict=True)
|
|
126
|
+
]
|
|
124
127
|
|
|
125
128
|
if "dpfedavg_clip_norm" not in config:
|
|
126
129
|
raise KeyError("Clipping threshold not supplied by the server.")
|
|
@@ -15,10 +15,9 @@
|
|
|
15
15
|
"""Contextmanager for a GrpcAdapter channel to the Flower server."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from collections.abc import Iterator
|
|
18
|
+
from collections.abc import Callable, Iterator
|
|
19
19
|
from contextlib import contextmanager
|
|
20
20
|
from logging import ERROR
|
|
21
|
-
from typing import Callable, Optional, Union
|
|
22
21
|
|
|
23
22
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
24
23
|
|
|
@@ -38,16 +37,15 @@ def grpc_adapter( # pylint: disable=R0913,too-many-positional-arguments
|
|
|
38
37
|
insecure: bool,
|
|
39
38
|
retry_invoker: RetryInvoker,
|
|
40
39
|
max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, # pylint: disable=W0613
|
|
41
|
-
root_certificates:
|
|
42
|
-
authentication_keys:
|
|
43
|
-
tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
|
|
44
|
-
|
|
40
|
+
root_certificates: bytes | str | None = None,
|
|
41
|
+
authentication_keys: (
|
|
42
|
+
tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] | None
|
|
43
|
+
) = None,
|
|
45
44
|
) -> Iterator[
|
|
46
45
|
tuple[
|
|
47
|
-
|
|
46
|
+
int,
|
|
47
|
+
Callable[[], tuple[Message, ObjectTree] | None],
|
|
48
48
|
Callable[[Message, ObjectTree], set[str]],
|
|
49
|
-
Callable[[], Optional[int]],
|
|
50
|
-
Callable[[], None],
|
|
51
49
|
Callable[[int], Run],
|
|
52
50
|
Callable[[str, int], Fab],
|
|
53
51
|
Callable[[int, str], bytes],
|
|
@@ -77,22 +75,21 @@ def grpc_adapter( # pylint: disable=R0913,too-many-positional-arguments
|
|
|
77
75
|
connection using the certificates will be established to an SSL-enabled
|
|
78
76
|
Flower server. Bytes won't work for the REST API.
|
|
79
77
|
authentication_keys : Optional[Tuple[PrivateKey, PublicKey]] (default: None)
|
|
80
|
-
|
|
78
|
+
SuperNode authentication is not supported for this transport type.
|
|
81
79
|
|
|
82
80
|
Returns
|
|
83
81
|
-------
|
|
82
|
+
node_id : int
|
|
84
83
|
receive : Callable[[], Optional[tuple[Message, ObjectTree]]]
|
|
85
84
|
send : Callable[[Message, ObjectTree], set[str]]
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
get_run : Optional[Callable]
|
|
89
|
-
get_fab : Optional[Callable]
|
|
85
|
+
get_run : Callable[[int], Run]
|
|
86
|
+
get_fab : Callable[[str, int], Fab]
|
|
90
87
|
pull_object : Callable[[str], bytes]
|
|
91
88
|
push_object : Callable[[str, bytes], None]
|
|
92
89
|
confirm_message_received : Callable[[str], None]
|
|
93
90
|
"""
|
|
94
91
|
if authentication_keys is not None:
|
|
95
|
-
log(ERROR, "
|
|
92
|
+
log(ERROR, "SuperNode authentication is not supported for this transport type.")
|
|
96
93
|
with grpc_request_response(
|
|
97
94
|
server_address=server_address,
|
|
98
95
|
insecure=insecure,
|
|
@@ -15,11 +15,10 @@
|
|
|
15
15
|
"""Contextmanager for a gRPC request-response channel to the Flower server."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from collections.abc import Iterator, Sequence
|
|
18
|
+
from collections.abc import Callable, Iterator, Sequence
|
|
19
19
|
from contextlib import contextmanager
|
|
20
20
|
from logging import ERROR
|
|
21
21
|
from pathlib import Path
|
|
22
|
-
from typing import Callable, Optional, Union, cast
|
|
23
22
|
|
|
24
23
|
import grpc
|
|
25
24
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
@@ -27,7 +26,6 @@ from cryptography.hazmat.primitives.asymmetric import ec
|
|
|
27
26
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
|
28
27
|
from flwr.common.constant import HEARTBEAT_CALL_TIMEOUT, HEARTBEAT_DEFAULT_INTERVAL
|
|
29
28
|
from flwr.common.grpc import create_channel, on_channel_state_change
|
|
30
|
-
from flwr.common.heartbeat import HeartbeatSender
|
|
31
29
|
from flwr.common.inflatable_protobuf_utils import (
|
|
32
30
|
make_confirm_message_received_fn_protobuf,
|
|
33
31
|
make_pull_object_fn_protobuf,
|
|
@@ -36,19 +34,24 @@ from flwr.common.inflatable_protobuf_utils import (
|
|
|
36
34
|
from flwr.common.logger import log
|
|
37
35
|
from flwr.common.message import Message, remove_content_from_message
|
|
38
36
|
from flwr.common.retry_invoker import RetryInvoker, _wrap_stub
|
|
39
|
-
from flwr.common.
|
|
40
|
-
|
|
37
|
+
from flwr.common.serde import (
|
|
38
|
+
fab_from_proto,
|
|
39
|
+
message_from_proto,
|
|
40
|
+
message_to_proto,
|
|
41
|
+
run_from_proto,
|
|
41
42
|
)
|
|
42
|
-
from flwr.common.serde import message_from_proto, message_to_proto, run_from_proto
|
|
43
43
|
from flwr.common.typing import Fab, Run
|
|
44
44
|
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
45
45
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
46
|
-
|
|
47
|
-
|
|
46
|
+
ActivateNodeRequest,
|
|
47
|
+
ActivateNodeResponse,
|
|
48
|
+
DeactivateNodeRequest,
|
|
48
49
|
PullMessagesRequest,
|
|
49
50
|
PullMessagesResponse,
|
|
50
51
|
PushMessagesRequest,
|
|
51
52
|
PushMessagesResponse,
|
|
53
|
+
RegisterNodeFleetRequest,
|
|
54
|
+
UnregisterNodeFleetRequest,
|
|
52
55
|
)
|
|
53
56
|
from flwr.proto.fleet_pb2_grpc import FleetStub # pylint: disable=E0611
|
|
54
57
|
from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
|
|
@@ -58,9 +61,11 @@ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
|
|
|
58
61
|
from flwr.proto.message_pb2 import ObjectTree # pylint: disable=E0611
|
|
59
62
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
60
63
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
64
|
+
from flwr.supercore.heartbeat import HeartbeatSender
|
|
65
|
+
from flwr.supercore.primitives.asymmetric import generate_key_pairs, public_key_to_bytes
|
|
61
66
|
|
|
62
|
-
from .client_interceptor import AuthenticateClientInterceptor
|
|
63
67
|
from .grpc_adapter import GrpcAdapter
|
|
68
|
+
from .node_auth_client_interceptor import NodeAuthClientInterceptor
|
|
64
69
|
|
|
65
70
|
|
|
66
71
|
@contextmanager
|
|
@@ -69,17 +74,16 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
69
74
|
insecure: bool,
|
|
70
75
|
retry_invoker: RetryInvoker,
|
|
71
76
|
max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, # pylint: disable=W0613
|
|
72
|
-
root_certificates:
|
|
73
|
-
authentication_keys:
|
|
74
|
-
tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
|
|
75
|
-
|
|
76
|
-
adapter_cls:
|
|
77
|
+
root_certificates: bytes | str | None = None,
|
|
78
|
+
authentication_keys: (
|
|
79
|
+
tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] | None
|
|
80
|
+
) = None,
|
|
81
|
+
adapter_cls: type[FleetStub] | type[GrpcAdapter] | None = None,
|
|
77
82
|
) -> Iterator[
|
|
78
83
|
tuple[
|
|
79
|
-
|
|
84
|
+
int,
|
|
85
|
+
Callable[[], tuple[Message, ObjectTree] | None],
|
|
80
86
|
Callable[[Message, ObjectTree], set[str]],
|
|
81
|
-
Callable[[], Optional[int]],
|
|
82
|
-
Callable[[], None],
|
|
83
87
|
Callable[[int], Run],
|
|
84
88
|
Callable[[str, int], Fab],
|
|
85
89
|
Callable[[int, str], bytes],
|
|
@@ -122,11 +126,11 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
122
126
|
|
|
123
127
|
Returns
|
|
124
128
|
-------
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
129
|
+
node_id : int
|
|
130
|
+
receive : Callable[[], Optional[tuple[Message, ObjectTree]]]
|
|
131
|
+
send : Callable[[Message, ObjectTree], set[str]]
|
|
132
|
+
get_run : Callable[[int], Run]
|
|
133
|
+
get_fab : Callable[[str, int], Fab]
|
|
130
134
|
pull_object : Callable[[str], bytes]
|
|
131
135
|
push_object : Callable[[str, bytes], None]
|
|
132
136
|
confirm_message_received : Callable[[str], None]
|
|
@@ -135,13 +139,16 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
135
139
|
root_certificates = Path(root_certificates).read_bytes()
|
|
136
140
|
|
|
137
141
|
# Automatic node auth: generate keys if user didn't provide any
|
|
142
|
+
self_registered = False
|
|
138
143
|
if authentication_keys is None:
|
|
144
|
+
self_registered = True
|
|
139
145
|
authentication_keys = generate_key_pairs()
|
|
140
146
|
|
|
141
147
|
# Always configure auth interceptor, with either user-provided or generated keys
|
|
142
148
|
interceptors: Sequence[grpc.UnaryUnaryClientInterceptor] = [
|
|
143
|
-
|
|
149
|
+
NodeAuthClientInterceptor(*authentication_keys),
|
|
144
150
|
]
|
|
151
|
+
node_pk = public_key_to_bytes(authentication_keys[1])
|
|
145
152
|
channel = create_channel(
|
|
146
153
|
server_address=server_address,
|
|
147
154
|
insecure=insecure,
|
|
@@ -155,12 +162,12 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
155
162
|
if adapter_cls is None:
|
|
156
163
|
adapter_cls = FleetStub
|
|
157
164
|
stub = adapter_cls(channel)
|
|
158
|
-
node:
|
|
165
|
+
node: Node | None = None
|
|
159
166
|
|
|
160
167
|
# Wrap stub
|
|
161
168
|
_wrap_stub(stub, retry_invoker)
|
|
162
169
|
###########################################################################
|
|
163
|
-
#
|
|
170
|
+
# SuperNode functions
|
|
164
171
|
###########################################################################
|
|
165
172
|
|
|
166
173
|
def send_node_heartbeat() -> bool:
|
|
@@ -197,22 +204,26 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
197
204
|
|
|
198
205
|
heartbeat_sender = HeartbeatSender(send_node_heartbeat)
|
|
199
206
|
|
|
200
|
-
def
|
|
201
|
-
"""
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
207
|
+
def register_node() -> None:
|
|
208
|
+
"""Register node with SuperLink."""
|
|
209
|
+
stub.RegisterNode(RegisterNodeFleetRequest(public_key=node_pk))
|
|
210
|
+
|
|
211
|
+
def activate_node() -> int:
|
|
212
|
+
"""Activate node and start heartbeat."""
|
|
213
|
+
req = ActivateNodeRequest(
|
|
214
|
+
public_key=node_pk,
|
|
215
|
+
heartbeat_interval=HEARTBEAT_DEFAULT_INTERVAL,
|
|
205
216
|
)
|
|
206
|
-
|
|
217
|
+
res: ActivateNodeResponse = stub.ActivateNode(req)
|
|
207
218
|
|
|
208
219
|
# Remember the node and start the heartbeat sender
|
|
209
220
|
nonlocal node
|
|
210
|
-
node =
|
|
221
|
+
node = Node(node_id=res.node_id)
|
|
211
222
|
heartbeat_sender.start()
|
|
212
223
|
return node.node_id
|
|
213
224
|
|
|
214
|
-
def
|
|
215
|
-
"""
|
|
225
|
+
def deactivate_node() -> None:
|
|
226
|
+
"""Deactivate node and stop heartbeat."""
|
|
216
227
|
# Get Node
|
|
217
228
|
nonlocal node
|
|
218
229
|
if node is None:
|
|
@@ -223,13 +234,25 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
223
234
|
heartbeat_sender.stop()
|
|
224
235
|
|
|
225
236
|
# Call FleetAPI
|
|
226
|
-
|
|
227
|
-
stub.
|
|
237
|
+
req = DeactivateNodeRequest(node_id=node.node_id)
|
|
238
|
+
stub.DeactivateNode(req)
|
|
239
|
+
|
|
240
|
+
def unregister_node() -> None:
|
|
241
|
+
"""Unregister node from SuperLink."""
|
|
242
|
+
# Get Node
|
|
243
|
+
nonlocal node
|
|
244
|
+
if node is None:
|
|
245
|
+
log(ERROR, "Node instance missing")
|
|
246
|
+
return
|
|
247
|
+
|
|
248
|
+
# Call FleetAPI
|
|
249
|
+
req = UnregisterNodeFleetRequest(node_id=node.node_id)
|
|
250
|
+
stub.UnregisterNode(req)
|
|
228
251
|
|
|
229
252
|
# Cleanup
|
|
230
253
|
node = None
|
|
231
254
|
|
|
232
|
-
def receive() ->
|
|
255
|
+
def receive() -> tuple[Message, ObjectTree] | None:
|
|
233
256
|
"""Pull a message with its ObjectTree from SuperLink."""
|
|
234
257
|
# Get Node
|
|
235
258
|
if node is None:
|
|
@@ -289,7 +312,7 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
289
312
|
get_fab_request = GetFabRequest(node=node, hash_str=fab_hash, run_id=run_id)
|
|
290
313
|
get_fab_response: GetFabResponse = stub.GetFab(request=get_fab_request)
|
|
291
314
|
|
|
292
|
-
return
|
|
315
|
+
return fab_from_proto(get_fab_response.fab)
|
|
293
316
|
|
|
294
317
|
def pull_object(run_id: int, object_id: str) -> bytes:
|
|
295
318
|
"""Pull the object from the SuperLink."""
|
|
@@ -331,12 +354,14 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
331
354
|
fn(object_id)
|
|
332
355
|
|
|
333
356
|
try:
|
|
357
|
+
if self_registered:
|
|
358
|
+
register_node()
|
|
359
|
+
node_id = activate_node()
|
|
334
360
|
# Yield methods
|
|
335
361
|
yield (
|
|
362
|
+
node_id,
|
|
336
363
|
receive,
|
|
337
364
|
send,
|
|
338
|
-
create_node,
|
|
339
|
-
delete_node,
|
|
340
365
|
get_run,
|
|
341
366
|
get_fab,
|
|
342
367
|
pull_object,
|
|
@@ -351,7 +376,9 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
351
376
|
if node is not None:
|
|
352
377
|
# Disable retrying
|
|
353
378
|
retry_invoker.max_tries = 1
|
|
354
|
-
|
|
379
|
+
deactivate_node()
|
|
380
|
+
if self_registered:
|
|
381
|
+
unregister_node()
|
|
355
382
|
except grpc.RpcError:
|
|
356
383
|
pass
|
|
357
384
|
channel.close()
|
|
@@ -15,7 +15,8 @@
|
|
|
15
15
|
"""GrpcAdapter implementation."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
import
|
|
18
|
+
import signal
|
|
19
|
+
import time
|
|
19
20
|
from logging import DEBUG
|
|
20
21
|
from typing import Any, TypeVar, cast
|
|
21
22
|
|
|
@@ -34,14 +35,18 @@ from flwr.common.constant import (
|
|
|
34
35
|
from flwr.common.version import package_name, package_version
|
|
35
36
|
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
36
37
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
38
|
+
ActivateNodeRequest,
|
|
39
|
+
ActivateNodeResponse,
|
|
40
|
+
DeactivateNodeRequest,
|
|
41
|
+
DeactivateNodeResponse,
|
|
41
42
|
PullMessagesRequest,
|
|
42
43
|
PullMessagesResponse,
|
|
43
44
|
PushMessagesRequest,
|
|
44
45
|
PushMessagesResponse,
|
|
46
|
+
RegisterNodeFleetRequest,
|
|
47
|
+
RegisterNodeFleetResponse,
|
|
48
|
+
UnregisterNodeFleetRequest,
|
|
49
|
+
UnregisterNodeFleetResponse,
|
|
45
50
|
)
|
|
46
51
|
from flwr.proto.grpcadapter_pb2 import MessageContainer # pylint: disable=E0611
|
|
47
52
|
from flwr.proto.grpcadapter_pb2_grpc import GrpcAdapterStub
|
|
@@ -58,6 +63,7 @@ from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
|
|
58
63
|
PushObjectResponse,
|
|
59
64
|
)
|
|
60
65
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
66
|
+
from flwr.supercore.constant import FORCE_EXIT_TIMEOUT_SECONDS
|
|
61
67
|
|
|
62
68
|
T = TypeVar("T", bound=GrpcMessage)
|
|
63
69
|
|
|
@@ -104,7 +110,9 @@ class GrpcAdapter:
|
|
|
104
110
|
DEBUG,
|
|
105
111
|
'Received shutdown signal: exit flag is set to ``"true"``. Exiting...',
|
|
106
112
|
)
|
|
107
|
-
|
|
113
|
+
signal.raise_signal(signal.SIGTERM)
|
|
114
|
+
# Give some time to handle the signal
|
|
115
|
+
time.sleep(FORCE_EXIT_TIMEOUT_SECONDS + 1)
|
|
108
116
|
|
|
109
117
|
# Check the grpc_message_name of the response
|
|
110
118
|
if container_res.grpc_message_name != response_type.__qualname__:
|
|
@@ -118,17 +126,29 @@ class GrpcAdapter:
|
|
|
118
126
|
response.ParseFromString(container_res.grpc_message_content)
|
|
119
127
|
return response
|
|
120
128
|
|
|
121
|
-
def
|
|
122
|
-
self, request:
|
|
123
|
-
) ->
|
|
129
|
+
def RegisterNode( # pylint: disable=C0103
|
|
130
|
+
self, request: RegisterNodeFleetRequest, **kwargs: Any
|
|
131
|
+
) -> RegisterNodeFleetResponse:
|
|
124
132
|
"""."""
|
|
125
|
-
return self._send_and_receive(request,
|
|
133
|
+
return self._send_and_receive(request, RegisterNodeFleetResponse, **kwargs)
|
|
126
134
|
|
|
127
|
-
def
|
|
128
|
-
self, request:
|
|
129
|
-
) ->
|
|
135
|
+
def ActivateNode( # pylint: disable=C0103
|
|
136
|
+
self, request: ActivateNodeRequest, **kwargs: Any
|
|
137
|
+
) -> ActivateNodeResponse:
|
|
130
138
|
"""."""
|
|
131
|
-
return self._send_and_receive(request,
|
|
139
|
+
return self._send_and_receive(request, ActivateNodeResponse, **kwargs)
|
|
140
|
+
|
|
141
|
+
def DeactivateNode( # pylint: disable=C0103
|
|
142
|
+
self, request: DeactivateNodeRequest, **kwargs: Any
|
|
143
|
+
) -> DeactivateNodeResponse:
|
|
144
|
+
"""."""
|
|
145
|
+
return self._send_and_receive(request, DeactivateNodeResponse, **kwargs)
|
|
146
|
+
|
|
147
|
+
def UnregisterNode( # pylint: disable=C0103
|
|
148
|
+
self, request: UnregisterNodeFleetRequest, **kwargs: Any
|
|
149
|
+
) -> UnregisterNodeFleetResponse:
|
|
150
|
+
"""."""
|
|
151
|
+
return self._send_and_receive(request, UnregisterNodeFleetResponse, **kwargs)
|
|
132
152
|
|
|
133
153
|
def SendNodeHeartbeat( # pylint: disable=C0103
|
|
134
154
|
self, request: SendNodeHeartbeatRequest, **kwargs: Any
|
|
@@ -15,7 +15,8 @@
|
|
|
15
15
|
"""Flower client interceptor."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from
|
|
18
|
+
from collections.abc import Callable
|
|
19
|
+
from typing import Any
|
|
19
20
|
|
|
20
21
|
import grpc
|
|
21
22
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
@@ -23,14 +24,11 @@ from google.protobuf.message import Message as GrpcMessage
|
|
|
23
24
|
|
|
24
25
|
from flwr.common import now
|
|
25
26
|
from flwr.common.constant import PUBLIC_KEY_HEADER, SIGNATURE_HEADER, TIMESTAMP_HEADER
|
|
26
|
-
from flwr.
|
|
27
|
-
public_key_to_bytes,
|
|
28
|
-
sign_message,
|
|
29
|
-
)
|
|
27
|
+
from flwr.supercore.primitives.asymmetric import public_key_to_bytes, sign_message
|
|
30
28
|
|
|
31
29
|
|
|
32
|
-
class
|
|
33
|
-
"""Client interceptor for
|
|
30
|
+
class NodeAuthClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type: ignore
|
|
31
|
+
"""Client interceptor for node authentication."""
|
|
34
32
|
|
|
35
33
|
def __init__(
|
|
36
34
|
self,
|
|
@@ -16,7 +16,7 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
from logging import WARN
|
|
19
|
-
from typing import
|
|
19
|
+
from typing import cast
|
|
20
20
|
|
|
21
21
|
from flwr.client.client import (
|
|
22
22
|
maybe_call_evaluate,
|
|
@@ -53,7 +53,7 @@ class UnknownServerMessage(Exception):
|
|
|
53
53
|
"""Exception indicating that the received message is unknown."""
|
|
54
54
|
|
|
55
55
|
|
|
56
|
-
def handle_control_message(message: Message) -> tuple[
|
|
56
|
+
def handle_control_message(message: Message) -> tuple[Message | None, int]:
|
|
57
57
|
"""Handle control part of the incoming message.
|
|
58
58
|
|
|
59
59
|
Parameters
|
|
@@ -35,14 +35,9 @@ from flwr.common.constant import MessageType
|
|
|
35
35
|
from flwr.common.logger import log
|
|
36
36
|
from flwr.common.secure_aggregation.crypto.shamir import create_shares
|
|
37
37
|
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
|
38
|
-
bytes_to_private_key,
|
|
39
|
-
bytes_to_public_key,
|
|
40
38
|
decrypt,
|
|
41
39
|
encrypt,
|
|
42
|
-
generate_key_pairs,
|
|
43
40
|
generate_shared_key,
|
|
44
|
-
private_key_to_bytes,
|
|
45
|
-
public_key_to_bytes,
|
|
46
41
|
)
|
|
47
42
|
from flwr.common.secure_aggregation.ndarrays_arithmetic import (
|
|
48
43
|
factor_combine,
|
|
@@ -64,6 +59,13 @@ from flwr.common.secure_aggregation.secaggplus_utils import (
|
|
|
64
59
|
share_keys_plaintext_separate,
|
|
65
60
|
)
|
|
66
61
|
from flwr.common.typing import ConfigRecordValues
|
|
62
|
+
from flwr.supercore.primitives.asymmetric import (
|
|
63
|
+
bytes_to_private_key,
|
|
64
|
+
bytes_to_public_key,
|
|
65
|
+
generate_key_pairs,
|
|
66
|
+
private_key_to_bytes,
|
|
67
|
+
public_key_to_bytes,
|
|
68
|
+
)
|
|
67
69
|
|
|
68
70
|
|
|
69
71
|
@dataclass
|
|
@@ -110,9 +112,9 @@ class SecAggPlusState:
|
|
|
110
112
|
updated_values = [
|
|
111
113
|
tuple(values[i : i + 2]) for i in range(0, len(values), 2)
|
|
112
114
|
]
|
|
113
|
-
new_v = dict(zip(keys, updated_values))
|
|
115
|
+
new_v = dict(zip(keys, updated_values, strict=True))
|
|
114
116
|
else:
|
|
115
|
-
new_v = dict(zip(keys, values))
|
|
117
|
+
new_v = dict(zip(keys, values, strict=True))
|
|
116
118
|
self.__setattr__(k, new_v)
|
|
117
119
|
|
|
118
120
|
def to_dict(self) -> dict[str, ConfigRecordValues]:
|
|
@@ -424,7 +426,7 @@ def _collect_masked_vectors(
|
|
|
424
426
|
raise ValueError("Not enough available neighbour clients.")
|
|
425
427
|
|
|
426
428
|
# Decrypt ciphertexts, verify their sources, and store shares.
|
|
427
|
-
for src, ciphertext in zip(srcs, ciphertexts):
|
|
429
|
+
for src, ciphertext in zip(srcs, ciphertexts, strict=True):
|
|
428
430
|
shared_key = state.ss2_dict[src]
|
|
429
431
|
plaintext = decrypt(shared_key, ciphertext)
|
|
430
432
|
actual_src, dst, rd_seed_share, sk1_share = share_keys_plaintext_separate(
|