flwr-nightly 1.8.0.dev20240314__py3-none-any.whl → 1.11.0.dev20240813__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/app.py +7 -0
- flwr/cli/build.py +150 -0
- flwr/cli/config_utils.py +219 -0
- flwr/cli/example.py +3 -1
- flwr/cli/install.py +227 -0
- flwr/cli/new/new.py +179 -48
- flwr/cli/new/templates/app/.gitignore.tpl +160 -0
- flwr/cli/new/templates/app/README.flowertune.md.tpl +56 -0
- flwr/cli/new/templates/app/README.md.tpl +1 -5
- flwr/cli/new/templates/app/code/__init__.py.tpl +1 -1
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +65 -0
- flwr/cli/new/templates/app/code/client.jax.py.tpl +56 -0
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +93 -0
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +3 -2
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +23 -11
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +97 -0
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +60 -1
- flwr/cli/new/templates/app/code/flwr_tune/__init__.py +15 -0
- flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +89 -0
- flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +126 -0
- flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +34 -0
- flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +57 -0
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +59 -0
- flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +48 -0
- flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +11 -0
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -0
- flwr/cli/new/templates/app/code/server.jax.py.tpl +20 -0
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +20 -0
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +17 -9
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +21 -18
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +24 -0
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +29 -1
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +99 -0
- flwr/cli/new/templates/app/code/task.jax.py.tpl +57 -0
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +102 -0
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +28 -23
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +53 -0
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +39 -0
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +38 -0
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +34 -0
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +39 -0
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +25 -12
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +29 -14
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +33 -0
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +29 -14
- flwr/cli/run/run.py +168 -17
- flwr/cli/utils.py +75 -4
- flwr/client/__init__.py +6 -1
- flwr/client/app.py +239 -248
- flwr/client/client_app.py +70 -9
- flwr/client/dpfedavg_numpy_client.py +1 -1
- flwr/client/grpc_adapter_client/__init__.py +15 -0
- flwr/client/grpc_adapter_client/connection.py +97 -0
- flwr/client/grpc_client/connection.py +18 -5
- flwr/client/grpc_rere_client/__init__.py +1 -1
- flwr/client/grpc_rere_client/client_interceptor.py +158 -0
- flwr/client/grpc_rere_client/connection.py +127 -33
- flwr/client/grpc_rere_client/grpc_adapter.py +140 -0
- flwr/client/heartbeat.py +74 -0
- flwr/client/message_handler/__init__.py +1 -1
- flwr/client/message_handler/message_handler.py +7 -7
- flwr/client/mod/__init__.py +5 -5
- flwr/client/mod/centraldp_mods.py +4 -2
- flwr/client/mod/comms_mods.py +4 -4
- flwr/client/mod/localdp_mod.py +9 -4
- flwr/client/mod/secure_aggregation/__init__.py +1 -1
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +1 -1
- flwr/client/mod/utils.py +1 -1
- flwr/client/node_state.py +60 -10
- flwr/client/node_state_tests.py +4 -3
- flwr/client/rest_client/__init__.py +1 -1
- flwr/client/rest_client/connection.py +177 -157
- flwr/client/supernode/__init__.py +26 -0
- flwr/client/supernode/app.py +464 -0
- flwr/client/typing.py +1 -0
- flwr/common/__init__.py +13 -11
- flwr/common/address.py +1 -1
- flwr/common/config.py +193 -0
- flwr/common/constant.py +42 -1
- flwr/common/context.py +26 -1
- flwr/common/date.py +1 -1
- flwr/common/dp.py +1 -1
- flwr/common/grpc.py +6 -2
- flwr/common/logger.py +79 -8
- flwr/common/message.py +167 -105
- flwr/common/object_ref.py +126 -25
- flwr/common/record/__init__.py +1 -1
- flwr/common/record/parametersrecord.py +0 -1
- flwr/common/record/recordset.py +78 -27
- flwr/common/recordset_compat.py +8 -1
- flwr/common/retry_invoker.py +25 -13
- flwr/common/secure_aggregation/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/shamir.py +1 -1
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +21 -2
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
- flwr/common/secure_aggregation/quantization.py +1 -1
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
- flwr/common/serde.py +209 -3
- flwr/common/telemetry.py +25 -0
- flwr/common/typing.py +38 -0
- flwr/common/version.py +14 -0
- flwr/proto/clientappio_pb2.py +41 -0
- flwr/proto/clientappio_pb2.pyi +110 -0
- flwr/proto/clientappio_pb2_grpc.py +101 -0
- flwr/proto/clientappio_pb2_grpc.pyi +40 -0
- flwr/proto/common_pb2.py +36 -0
- flwr/proto/common_pb2.pyi +121 -0
- flwr/proto/common_pb2_grpc.py +4 -0
- flwr/proto/common_pb2_grpc.pyi +4 -0
- flwr/proto/driver_pb2.py +26 -19
- flwr/proto/driver_pb2.pyi +34 -0
- flwr/proto/driver_pb2_grpc.py +70 -0
- flwr/proto/driver_pb2_grpc.pyi +28 -0
- flwr/proto/exec_pb2.py +43 -0
- flwr/proto/exec_pb2.pyi +95 -0
- flwr/proto/exec_pb2_grpc.py +101 -0
- flwr/proto/exec_pb2_grpc.pyi +41 -0
- flwr/proto/fab_pb2.py +30 -0
- flwr/proto/fab_pb2.pyi +56 -0
- flwr/proto/fab_pb2_grpc.py +4 -0
- flwr/proto/fab_pb2_grpc.pyi +4 -0
- flwr/proto/fleet_pb2.py +29 -23
- flwr/proto/fleet_pb2.pyi +33 -0
- flwr/proto/fleet_pb2_grpc.py +102 -0
- flwr/proto/fleet_pb2_grpc.pyi +35 -0
- flwr/proto/grpcadapter_pb2.py +32 -0
- flwr/proto/grpcadapter_pb2.pyi +43 -0
- flwr/proto/grpcadapter_pb2_grpc.py +66 -0
- flwr/proto/grpcadapter_pb2_grpc.pyi +24 -0
- flwr/proto/message_pb2.py +41 -0
- flwr/proto/message_pb2.pyi +122 -0
- flwr/proto/message_pb2_grpc.py +4 -0
- flwr/proto/message_pb2_grpc.pyi +4 -0
- flwr/proto/run_pb2.py +35 -0
- flwr/proto/run_pb2.pyi +76 -0
- flwr/proto/run_pb2_grpc.py +4 -0
- flwr/proto/run_pb2_grpc.pyi +4 -0
- flwr/proto/task_pb2.py +7 -8
- flwr/proto/task_pb2.pyi +8 -5
- flwr/server/__init__.py +4 -8
- flwr/server/app.py +298 -350
- flwr/server/compat/app.py +6 -57
- flwr/server/compat/app_utils.py +5 -4
- flwr/server/compat/driver_client_proxy.py +29 -48
- flwr/server/compat/legacy_context.py +5 -4
- flwr/server/driver/__init__.py +2 -0
- flwr/server/driver/driver.py +22 -132
- flwr/server/driver/grpc_driver.py +224 -74
- flwr/server/driver/inmemory_driver.py +183 -0
- flwr/server/history.py +20 -20
- flwr/server/run_serverapp.py +121 -34
- flwr/server/server.py +11 -7
- flwr/server/server_app.py +59 -10
- flwr/server/serverapp_components.py +52 -0
- flwr/server/strategy/__init__.py +2 -2
- flwr/server/strategy/bulyan.py +1 -1
- flwr/server/strategy/dp_adaptive_clipping.py +3 -3
- flwr/server/strategy/dp_fixed_clipping.py +4 -3
- flwr/server/strategy/dpfedavg_adaptive.py +1 -1
- flwr/server/strategy/dpfedavg_fixed.py +1 -1
- flwr/server/strategy/fedadagrad.py +1 -1
- flwr/server/strategy/fedadam.py +1 -1
- flwr/server/strategy/fedavg_android.py +1 -1
- flwr/server/strategy/fedavgm.py +1 -1
- flwr/server/strategy/fedmedian.py +1 -1
- flwr/server/strategy/fedopt.py +1 -1
- flwr/server/strategy/fedprox.py +1 -1
- flwr/server/strategy/fedxgb_bagging.py +1 -1
- flwr/server/strategy/fedxgb_cyclic.py +1 -1
- flwr/server/strategy/fedxgb_nn_avg.py +1 -1
- flwr/server/strategy/fedyogi.py +1 -1
- flwr/server/strategy/krum.py +1 -1
- flwr/server/strategy/qfedavg.py +1 -1
- flwr/server/superlink/driver/__init__.py +1 -1
- flwr/server/superlink/driver/driver_grpc.py +1 -1
- flwr/server/superlink/driver/driver_servicer.py +51 -4
- flwr/server/superlink/ffs/__init__.py +24 -0
- flwr/server/superlink/ffs/disk_ffs.py +104 -0
- flwr/server/superlink/ffs/ffs.py +79 -0
- flwr/server/superlink/fleet/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_adapter/__init__.py +15 -0
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +131 -0
- flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +8 -2
- flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +30 -2
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +214 -0
- flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
- flwr/server/superlink/fleet/message_handler/message_handler.py +42 -2
- flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/rest_rere/rest_api.py +59 -1
- flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
- flwr/server/superlink/fleet/vce/backend/backend.py +5 -5
- flwr/server/superlink/fleet/vce/backend/raybackend.py +53 -56
- flwr/server/superlink/fleet/vce/vce_api.py +190 -127
- flwr/server/superlink/state/__init__.py +1 -1
- flwr/server/superlink/state/in_memory_state.py +159 -42
- flwr/server/superlink/state/sqlite_state.py +243 -39
- flwr/server/superlink/state/state.py +81 -6
- flwr/server/superlink/state/state_factory.py +11 -2
- flwr/server/superlink/state/utils.py +62 -0
- flwr/server/typing.py +2 -0
- flwr/server/utils/__init__.py +1 -1
- flwr/server/utils/tensorboard.py +1 -1
- flwr/server/utils/validator.py +23 -9
- flwr/server/workflow/default_workflows.py +67 -25
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -6
- flwr/simulation/__init__.py +7 -4
- flwr/simulation/app.py +67 -36
- flwr/simulation/ray_transport/__init__.py +1 -1
- flwr/simulation/ray_transport/ray_actor.py +20 -46
- flwr/simulation/ray_transport/ray_client_proxy.py +36 -16
- flwr/simulation/run_simulation.py +308 -92
- flwr/superexec/__init__.py +21 -0
- flwr/superexec/app.py +184 -0
- flwr/superexec/deployment.py +185 -0
- flwr/superexec/exec_grpc.py +55 -0
- flwr/superexec/exec_servicer.py +70 -0
- flwr/superexec/executor.py +75 -0
- flwr/superexec/simulation.py +193 -0
- {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/METADATA +10 -6
- flwr_nightly-1.11.0.dev20240813.dist-info/RECORD +288 -0
- flwr_nightly-1.11.0.dev20240813.dist-info/entry_points.txt +10 -0
- flwr/cli/flower_toml.py +0 -140
- flwr/cli/new/templates/app/flower.toml.tpl +0 -13
- flwr/cli/new/templates/app/requirements.numpy.txt.tpl +0 -2
- flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +0 -4
- flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl +0 -4
- flwr_nightly-1.8.0.dev20240314.dist-info/RECORD +0 -211
- flwr_nightly-1.8.0.dev20240314.dist-info/entry_points.txt +0 -9
- {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/WHEEL +0 -0
flwr/client/client_app.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -15,18 +15,71 @@
|
|
|
15
15
|
"""Flower ClientApp."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
import inspect
|
|
18
19
|
from typing import Callable, List, Optional
|
|
19
20
|
|
|
21
|
+
from flwr.client.client import Client
|
|
20
22
|
from flwr.client.message_handler.message_handler import (
|
|
21
23
|
handle_legacy_message_from_msgtype,
|
|
22
24
|
)
|
|
23
25
|
from flwr.client.mod.utils import make_ffn
|
|
24
|
-
from flwr.client.typing import
|
|
26
|
+
from flwr.client.typing import ClientFnExt, Mod
|
|
25
27
|
from flwr.common import Context, Message, MessageType
|
|
28
|
+
from flwr.common.logger import warn_deprecated_feature, warn_preview_feature
|
|
26
29
|
|
|
27
30
|
from .typing import ClientAppCallable
|
|
28
31
|
|
|
29
32
|
|
|
33
|
+
def _alert_erroneous_client_fn() -> None:
|
|
34
|
+
raise ValueError(
|
|
35
|
+
"A `ClientApp` cannot make use of a `client_fn` that does "
|
|
36
|
+
"not have a signature in the form: `def client_fn(context: "
|
|
37
|
+
"Context)`. You can import the `Context` like this: "
|
|
38
|
+
"`from flwr.common import Context`"
|
|
39
|
+
)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _inspect_maybe_adapt_client_fn_signature(client_fn: ClientFnExt) -> ClientFnExt:
    """Validate ``client_fn`` and wrap it if it uses the deprecated signature.

    Parameters
    ----------
    client_fn : ClientFnExt
        User-provided client factory. The expected form is
        ``def client_fn(context: Context)``. The deprecated form taking a client
        id string is still accepted. It is detected by a ``str`` annotation on
        the single parameter, or by that parameter being named ``cid``.

    Returns
    -------
    ClientFnExt
        ``client_fn`` itself when it already has the expected form. Otherwise,
        an adaptor that accepts a ``Context`` and calls the deprecated
        ``client_fn`` with a string id.

    Raises
    ------
    ValueError
        If ``client_fn`` does not take exactly one parameter (including the
        zero-parameter case).
    """
    client_fn_args = inspect.signature(client_fn).parameters

    # Check the arity *before* looking up the first parameter. Otherwise a
    # zero-argument `client_fn` fails with an IndexError instead of the
    # explanatory ValueError.
    if len(client_fn_args) != 1:
        _alert_erroneous_client_fn()

    first_arg = next(iter(client_fn_args))
    first_arg_type = client_fn_args[first_arg].annotation

    if first_arg_type is str or first_arg == "cid":
        # The previous (string client id) signature for `client_fn` seems to be used
        warn_deprecated_feature(
            "`client_fn` now expects a signature `def client_fn(context: Context)`. "
            "The provided `client_fn` has signature: "
            f"{dict(client_fn_args.items())}. You can import the `Context` like this:"
            " `from flwr.common import Context`"
        )

        # Wrap the deprecated client_fn inside a function with the expected signature
        def adaptor_fn(
            context: Context,
        ) -> Client:  # pylint: disable=unused-argument
            # If `partition-id` is defined, pass it. Otherwise pass `node_id`,
            # which is expected to always be set when the Context is created.
            cid = context.node_config.get("partition-id", context.node_id)
            return client_fn(str(cid))  # type: ignore

        return adaptor_fn

    return client_fn
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class ClientAppException(Exception):
    """Error signalling that running a ClientApp raised an exception.

    The formatted text is stored on ``self.message``. It is also the single
    argument passed to ``Exception``, so ``str()`` of the error returns it.
    """

    def __init__(self, message: str):
        prefix = "\nException " + type(self).__name__ + " occurred. Message: "
        self.message = prefix + message
        super().__init__(self.message)
|
|
81
|
+
|
|
82
|
+
|
|
30
83
|
class ClientApp:
|
|
31
84
|
"""Flower ClientApp.
|
|
32
85
|
|
|
@@ -38,7 +91,7 @@ class ClientApp:
|
|
|
38
91
|
>>> class FlowerClient(NumPyClient):
|
|
39
92
|
>>> # ...
|
|
40
93
|
>>>
|
|
41
|
-
>>> def client_fn(
|
|
94
|
+
>>> def client_fn(context: Context):
|
|
42
95
|
>>> return FlowerClient().to_client()
|
|
43
96
|
>>>
|
|
44
97
|
>>> app = ClientApp(client_fn)
|
|
@@ -55,7 +108,7 @@ class ClientApp:
|
|
|
55
108
|
|
|
56
109
|
def __init__(
|
|
57
110
|
self,
|
|
58
|
-
client_fn: Optional[
|
|
111
|
+
client_fn: Optional[ClientFnExt] = None, # Only for backward compatibility
|
|
59
112
|
mods: Optional[List[Mod]] = None,
|
|
60
113
|
) -> None:
|
|
61
114
|
self._mods: List[Mod] = mods if mods is not None else []
|
|
@@ -64,6 +117,8 @@ class ClientApp:
|
|
|
64
117
|
self._call: Optional[ClientAppCallable] = None
|
|
65
118
|
if client_fn is not None:
|
|
66
119
|
|
|
120
|
+
client_fn = _inspect_maybe_adapt_client_fn_signature(client_fn)
|
|
121
|
+
|
|
67
122
|
def ffn(
|
|
68
123
|
message: Message,
|
|
69
124
|
context: Context,
|
|
@@ -115,7 +170,7 @@ class ClientApp:
|
|
|
115
170
|
>>> def train(message: Message, context: Context) -> Message:
|
|
116
171
|
>>> print("ClientApp training running")
|
|
117
172
|
>>> # Create and return an echo reply message
|
|
118
|
-
>>> return message.create_reply(content=message.content()
|
|
173
|
+
>>> return message.create_reply(content=message.content())
|
|
119
174
|
"""
|
|
120
175
|
|
|
121
176
|
def train_decorator(train_fn: ClientAppCallable) -> ClientAppCallable:
|
|
@@ -123,6 +178,8 @@ class ClientApp:
|
|
|
123
178
|
if self._call:
|
|
124
179
|
raise _registration_error(MessageType.TRAIN)
|
|
125
180
|
|
|
181
|
+
warn_preview_feature("ClientApp-register-train-function")
|
|
182
|
+
|
|
126
183
|
# Register provided function with the ClientApp object
|
|
127
184
|
# Wrap mods around the wrapped step function
|
|
128
185
|
self._train = make_ffn(train_fn, self._mods)
|
|
@@ -143,7 +200,7 @@ class ClientApp:
|
|
|
143
200
|
>>> def evaluate(message: Message, context: Context) -> Message:
|
|
144
201
|
>>> print("ClientApp evaluation running")
|
|
145
202
|
>>> # Create and return an echo reply message
|
|
146
|
-
>>> return message.create_reply(content=message.content()
|
|
203
|
+
>>> return message.create_reply(content=message.content())
|
|
147
204
|
"""
|
|
148
205
|
|
|
149
206
|
def evaluate_decorator(evaluate_fn: ClientAppCallable) -> ClientAppCallable:
|
|
@@ -151,6 +208,8 @@ class ClientApp:
|
|
|
151
208
|
if self._call:
|
|
152
209
|
raise _registration_error(MessageType.EVALUATE)
|
|
153
210
|
|
|
211
|
+
warn_preview_feature("ClientApp-register-evaluate-function")
|
|
212
|
+
|
|
154
213
|
# Register provided function with the ClientApp object
|
|
155
214
|
# Wrap mods around the wrapped step function
|
|
156
215
|
self._evaluate = make_ffn(evaluate_fn, self._mods)
|
|
@@ -171,7 +230,7 @@ class ClientApp:
|
|
|
171
230
|
>>> def query(message: Message, context: Context) -> Message:
|
|
172
231
|
>>> print("ClientApp query running")
|
|
173
232
|
>>> # Create and return an echo reply message
|
|
174
|
-
>>> return message.create_reply(content=message.content()
|
|
233
|
+
>>> return message.create_reply(content=message.content())
|
|
175
234
|
"""
|
|
176
235
|
|
|
177
236
|
def query_decorator(query_fn: ClientAppCallable) -> ClientAppCallable:
|
|
@@ -179,6 +238,8 @@ class ClientApp:
|
|
|
179
238
|
if self._call:
|
|
180
239
|
raise _registration_error(MessageType.QUERY)
|
|
181
240
|
|
|
241
|
+
warn_preview_feature("ClientApp-register-query-function")
|
|
242
|
+
|
|
182
243
|
# Register provided function with the ClientApp object
|
|
183
244
|
# Wrap mods around the wrapped step function
|
|
184
245
|
self._query = make_ffn(query_fn, self._mods)
|
|
@@ -205,7 +266,7 @@ def _registration_error(fn_name: str) -> ValueError:
|
|
|
205
266
|
>>> def client_fn(cid) -> Client:
|
|
206
267
|
>>> return FlowerClient().to_client()
|
|
207
268
|
>>>
|
|
208
|
-
>>> app = ClientApp(
|
|
269
|
+
>>> app = ClientApp(
|
|
209
270
|
>>> client_fn=client_fn,
|
|
210
271
|
>>> )
|
|
211
272
|
|
|
@@ -218,7 +279,7 @@ def _registration_error(fn_name: str) -> ValueError:
|
|
|
218
279
|
>>> print("ClientApp {fn_name} running")
|
|
219
280
|
>>> # Create and return an echo reply message
|
|
220
281
|
>>> return message.create_reply(
|
|
221
|
-
>>> content=message.content()
|
|
282
|
+
>>> content=message.content()
|
|
222
283
|
>>> )
|
|
223
284
|
""",
|
|
224
285
|
)
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Client-side part of the GrpcAdapter transport layer."""
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Contextmanager for a GrpcAdapter channel to the Flower server."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
from contextlib import contextmanager
|
|
19
|
+
from logging import ERROR
|
|
20
|
+
from typing import Callable, Iterator, Optional, Tuple, Union
|
|
21
|
+
|
|
22
|
+
from cryptography.hazmat.primitives.asymmetric import ec
|
|
23
|
+
|
|
24
|
+
from flwr.client.grpc_rere_client.connection import grpc_request_response
|
|
25
|
+
from flwr.client.grpc_rere_client.grpc_adapter import GrpcAdapter
|
|
26
|
+
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
|
27
|
+
from flwr.common.logger import log
|
|
28
|
+
from flwr.common.message import Message
|
|
29
|
+
from flwr.common.retry_invoker import RetryInvoker
|
|
30
|
+
from flwr.common.typing import Fab, Run
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@contextmanager
def grpc_adapter(  # pylint: disable=R0913
    server_address: str,
    insecure: bool,
    retry_invoker: RetryInvoker,
    max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,  # pylint: disable=W0613
    root_certificates: Optional[Union[bytes, str]] = None,
    authentication_keys: Optional[  # pylint: disable=unused-argument
        Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
    ] = None,
) -> Iterator[
    Tuple[
        Callable[[], Optional[Message]],
        Callable[[Message], None],
        Optional[Callable[[], Optional[int]]],
        Optional[Callable[[], None]],
        Optional[Callable[[int], Run]],
        Optional[Callable[[str], Fab]],
    ]
]:
    """Open a request/response connection to the server through GrpcAdapter.

    This is a thin wrapper around ``grpc_request_response``. It selects the
    ``GrpcAdapter`` stub class and never forwards authentication keys.

    Parameters
    ----------
    server_address : str
        Address of the Flower server, passed through unchanged to
        ``grpc_request_response``. NOTE(review): earlier docs described an IPv6
        address with an ``http://``/``https://`` scheme (``"http://[::]:8080"``).
        Confirm the format the gRPC path actually expects.
    insecure : bool
        Whether to open an insecure connection, passed through unchanged.
    retry_invoker : RetryInvoker
        Retry policy used when talking to the server, passed through unchanged.
    max_message_length : int
        Passed through unchanged to ``grpc_request_response``.
        NOTE(review): earlier docs called this argument "ignored". Whether it
        has an effect depends on that function.
    root_certificates : Optional[Union[bytes, str]] (default: None)
        Root certificates, given as bytes or as a path string. Passed through
        unchanged.
    authentication_keys : Optional[Tuple[PrivateKey, PublicKey]] (default: None)
        Client authentication is not supported for this transport type. If keys
        are given, an error is logged and ``None`` is forwarded instead.

    Yields
    ------
    Tuple
        The primitives produced by ``grpc_request_response``, in this order:
        ``receive``, ``send``, ``create_node``, ``delete_node``, ``get_run``,
        ``get_fab``. The last four are optional.
    """
    if authentication_keys is not None:
        # Keys are accepted only for signature compatibility with the other
        # transports. They are dropped when the connection is opened below.
        log(ERROR, "Client authentication is not supported for this transport type.")

    with grpc_request_response(
        adapter_cls=GrpcAdapter,
        server_address=server_address,
        insecure=insecure,
        root_certificates=root_certificates,
        retry_invoker=retry_invoker,
        max_message_length=max_message_length,
        authentication_keys=None,  # Authentication is not supported
    ) as primitives:
        yield primitives
|
|
@@ -17,12 +17,15 @@
|
|
|
17
17
|
|
|
18
18
|
import uuid
|
|
19
19
|
from contextlib import contextmanager
|
|
20
|
-
from logging import DEBUG
|
|
20
|
+
from logging import DEBUG, ERROR
|
|
21
21
|
from pathlib import Path
|
|
22
22
|
from queue import Queue
|
|
23
23
|
from typing import Callable, Iterator, Optional, Tuple, Union, cast
|
|
24
24
|
|
|
25
|
+
from cryptography.hazmat.primitives.asymmetric import ec
|
|
26
|
+
|
|
25
27
|
from flwr.common import (
|
|
28
|
+
DEFAULT_TTL,
|
|
26
29
|
GRPC_MAX_MESSAGE_LENGTH,
|
|
27
30
|
ConfigsRecord,
|
|
28
31
|
Message,
|
|
@@ -35,6 +38,7 @@ from flwr.common.constant import MessageType, MessageTypeLegacy
|
|
|
35
38
|
from flwr.common.grpc import create_channel
|
|
36
39
|
from flwr.common.logger import log
|
|
37
40
|
from flwr.common.retry_invoker import RetryInvoker
|
|
41
|
+
from flwr.common.typing import Fab, Run
|
|
38
42
|
from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
|
|
39
43
|
ClientMessage,
|
|
40
44
|
Reason,
|
|
@@ -55,18 +59,23 @@ def on_channel_state_change(channel_connectivity: str) -> None:
|
|
|
55
59
|
|
|
56
60
|
|
|
57
61
|
@contextmanager
|
|
58
|
-
def grpc_connection( # pylint: disable=R0915
|
|
62
|
+
def grpc_connection( # pylint: disable=R0913, R0915
|
|
59
63
|
server_address: str,
|
|
60
64
|
insecure: bool,
|
|
61
65
|
retry_invoker: RetryInvoker, # pylint: disable=unused-argument
|
|
62
66
|
max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
|
|
63
67
|
root_certificates: Optional[Union[bytes, str]] = None,
|
|
68
|
+
authentication_keys: Optional[ # pylint: disable=unused-argument
|
|
69
|
+
Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
|
|
70
|
+
] = None,
|
|
64
71
|
) -> Iterator[
|
|
65
72
|
Tuple[
|
|
66
73
|
Callable[[], Optional[Message]],
|
|
67
74
|
Callable[[Message], None],
|
|
75
|
+
Optional[Callable[[], Optional[int]]],
|
|
68
76
|
Optional[Callable[[], None]],
|
|
69
|
-
Optional[Callable[[],
|
|
77
|
+
Optional[Callable[[int], Run]],
|
|
78
|
+
Optional[Callable[[str], Fab]],
|
|
70
79
|
]
|
|
71
80
|
]:
|
|
72
81
|
"""Establish a gRPC connection to a gRPC server.
|
|
@@ -94,6 +103,8 @@ def grpc_connection( # pylint: disable=R0915
|
|
|
94
103
|
The PEM-encoded root certificates as a byte string or a path string.
|
|
95
104
|
If provided, a secure connection using the certificates will be
|
|
96
105
|
established to an SSL-enabled Flower server.
|
|
106
|
+
authentication_keys : Optional[Tuple[PrivateKey, PublicKey]] (default: None)
|
|
107
|
+
Client authentication is not supported for this transport type.
|
|
97
108
|
|
|
98
109
|
Returns
|
|
99
110
|
-------
|
|
@@ -116,6 +127,8 @@ def grpc_connection( # pylint: disable=R0915
|
|
|
116
127
|
"""
|
|
117
128
|
if isinstance(root_certificates, str):
|
|
118
129
|
root_certificates = Path(root_certificates).read_bytes()
|
|
130
|
+
if authentication_keys is not None:
|
|
131
|
+
log(ERROR, "Client authentication is not supported for this transport type.")
|
|
119
132
|
|
|
120
133
|
channel = create_channel(
|
|
121
134
|
server_address=server_address,
|
|
@@ -180,7 +193,7 @@ def grpc_connection( # pylint: disable=R0915
|
|
|
180
193
|
dst_node_id=0,
|
|
181
194
|
reply_to_message="",
|
|
182
195
|
group_id="",
|
|
183
|
-
ttl=
|
|
196
|
+
ttl=DEFAULT_TTL,
|
|
184
197
|
message_type=message_type,
|
|
185
198
|
),
|
|
186
199
|
content=recordset,
|
|
@@ -223,7 +236,7 @@ def grpc_connection( # pylint: disable=R0915
|
|
|
223
236
|
|
|
224
237
|
try:
|
|
225
238
|
# Yield methods
|
|
226
|
-
yield (receive, send, None, None)
|
|
239
|
+
yield (receive, send, None, None, None, None)
|
|
227
240
|
finally:
|
|
228
241
|
# Make sure to have a final
|
|
229
242
|
channel.close()
|
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Flower client interceptor."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
import base64
|
|
19
|
+
import collections
|
|
20
|
+
from typing import Any, Callable, Optional, Sequence, Tuple, Union
|
|
21
|
+
|
|
22
|
+
import grpc
|
|
23
|
+
from cryptography.hazmat.primitives.asymmetric import ec
|
|
24
|
+
|
|
25
|
+
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
|
26
|
+
bytes_to_public_key,
|
|
27
|
+
compute_hmac,
|
|
28
|
+
generate_shared_key,
|
|
29
|
+
public_key_to_bytes,
|
|
30
|
+
)
|
|
31
|
+
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
32
|
+
CreateNodeRequest,
|
|
33
|
+
DeleteNodeRequest,
|
|
34
|
+
PingRequest,
|
|
35
|
+
PullTaskInsRequest,
|
|
36
|
+
PushTaskResRequest,
|
|
37
|
+
)
|
|
38
|
+
from flwr.proto.run_pb2 import GetRunRequest # pylint: disable=E0611
|
|
39
|
+
|
|
40
|
+
# gRPC metadata key for a url-safe base64-encoded public key. The client always
# attaches its own key under this name. After a CreateNodeRequest, the server's
# key is read back from the response's initial metadata under the same name.
_PUBLIC_KEY_HEADER = "public-key"
# gRPC metadata key for the url-safe base64-encoded HMAC of the serialized
# request, computed with the client/server shared secret.
_AUTH_TOKEN_HEADER = "auth-token"

# Request message types recognised by `AuthenticateClientInterceptor`.
Request = Union[
    CreateNodeRequest,
    DeleteNodeRequest,
    PullTaskInsRequest,
    PushTaskResRequest,
    GetRunRequest,
    PingRequest,
]
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _get_value_from_tuples(
|
|
54
|
+
key_string: str, tuples: Sequence[Tuple[str, Union[str, bytes]]]
|
|
55
|
+
) -> bytes:
|
|
56
|
+
value = next((value for key, value in tuples if key == key_string), "")
|
|
57
|
+
if isinstance(value, str):
|
|
58
|
+
return value.encode()
|
|
59
|
+
|
|
60
|
+
return value
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class _ClientCallDetails(
    collections.namedtuple(
        "_ClientCallDetails", "method timeout metadata credentials"
    ),
    grpc.ClientCallDetails,  # type: ignore
):
    """Immutable record describing one outgoing RPC.

    It carries the ``method``, ``timeout``, ``metadata`` and ``credentials`` of a
    call. ``AuthenticateClientInterceptor`` builds a new instance with extended
    metadata and passes it as the first argument to the continuation function.
    """
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor):  # type: ignore
    """Client interceptor for client authentication.

    Every unary-unary call gets the client's public key attached as metadata.

    - A ``CreateNodeRequest`` is used to learn the server's public key (read from
      the response's initial metadata) and to derive a shared secret.
    - The other recognised request types are also signed: an HMAC of the
      serialized request, keyed with that shared secret, is attached.
    """

    def __init__(
        self,
        private_key: ec.EllipticCurvePrivateKey,
        public_key: ec.EllipticCurvePublicKey,
    ):
        """Store the client's key pair and precompute its encoded public key.

        Parameters
        ----------
        private_key : ec.EllipticCurvePrivateKey
            Client private key, used to derive the shared secret.
        public_key : ec.EllipticCurvePublicKey
            Client public key, sent with every request.
        """
        self.private_key = private_key
        self.public_key = public_key
        # Both are populated only after a CreateNodeRequest has completed.
        self.shared_secret: Optional[bytes] = None
        self.server_public_key: Optional[ec.EllipticCurvePublicKey] = None
        # Computed once: url-safe base64 of the serialized public key.
        self.encoded_public_key = base64.urlsafe_b64encode(
            public_key_to_bytes(self.public_key)
        )

    def intercept_unary_unary(
        self,
        continuation: Callable[[Any, Any], Any],
        client_call_details: grpc.ClientCallDetails,
        request: Request,
    ) -> grpc.Call:
        """Flower client interceptor.

        Intercept unary call from client and add necessary authentication header in the
        RPC metadata.

        Parameters
        ----------
        continuation : Callable[[Any, Any], Any]
            Function that performs the actual call with the (new) call details.
        client_call_details : grpc.ClientCallDetails
            Details of the intercepted call. Existing metadata is preserved.
        request : Request
            The outgoing request message.

        Returns
        -------
        grpc.Call
            Whatever ``continuation`` returns.

        Raises
        ------
        RuntimeError
            If a request that must be signed is sent before a CreateNodeRequest
            has established the shared secret.
        """
        metadata = []
        postprocess = False
        if client_call_details.metadata is not None:
            metadata = list(client_call_details.metadata)

        # Always add the public key header
        metadata.append(
            (
                _PUBLIC_KEY_HEADER,
                self.encoded_public_key,
            )
        )

        if isinstance(request, CreateNodeRequest):
            # The key exchange happens on this call, so no shared secret exists
            # yet to sign it with. The secret is derived from the response below.
            postprocess = True
        elif isinstance(
            request,
            (
                DeleteNodeRequest,
                PullTaskInsRequest,
                PushTaskResRequest,
                GetRunRequest,
                PingRequest,
            ),
        ):
            if self.shared_secret is None:
                raise RuntimeError("Failure to compute hmac")

            # `deterministic` is passed by keyword: the protobuf Python API
            # declares `SerializeToString(**kwargs)`, so a positional argument
            # fails on the pure-Python backend. The serialized bytes are the same.
            metadata.append(
                (
                    _AUTH_TOKEN_HEADER,
                    base64.urlsafe_b64encode(
                        compute_hmac(
                            self.shared_secret,
                            request.SerializeToString(deterministic=True),
                        )
                    ),
                )
            )
        # NOTE(review): request types not listed above are sent with only the
        # public-key header (no auth token). Confirm that no newer RPC needs
        # signing.

        client_call_details = _ClientCallDetails(
            client_call_details.method,
            client_call_details.timeout,
            metadata,
            client_call_details.credentials,
        )

        response = continuation(client_call_details, request)
        if postprocess:
            server_public_key_bytes = base64.urlsafe_b64decode(
                _get_value_from_tuples(_PUBLIC_KEY_HEADER, response.initial_metadata())
            )
            self.server_public_key = bytes_to_public_key(server_public_key_bytes)
            self.shared_secret = generate_shared_key(
                self.private_key, self.server_public_key
            )
        return response
|