flwr-nightly 1.8.0.dev20240314__py3-none-any.whl → 1.11.0.dev20240813__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/app.py +7 -0
- flwr/cli/build.py +150 -0
- flwr/cli/config_utils.py +219 -0
- flwr/cli/example.py +3 -1
- flwr/cli/install.py +227 -0
- flwr/cli/new/new.py +179 -48
- flwr/cli/new/templates/app/.gitignore.tpl +160 -0
- flwr/cli/new/templates/app/README.flowertune.md.tpl +56 -0
- flwr/cli/new/templates/app/README.md.tpl +1 -5
- flwr/cli/new/templates/app/code/__init__.py.tpl +1 -1
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +65 -0
- flwr/cli/new/templates/app/code/client.jax.py.tpl +56 -0
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +93 -0
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +3 -2
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +23 -11
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +97 -0
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +60 -1
- flwr/cli/new/templates/app/code/flwr_tune/__init__.py +15 -0
- flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +89 -0
- flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +126 -0
- flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +34 -0
- flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +57 -0
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +59 -0
- flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +48 -0
- flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +11 -0
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -0
- flwr/cli/new/templates/app/code/server.jax.py.tpl +20 -0
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +20 -0
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +17 -9
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +21 -18
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +24 -0
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +29 -1
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +99 -0
- flwr/cli/new/templates/app/code/task.jax.py.tpl +57 -0
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +102 -0
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +28 -23
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +53 -0
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +39 -0
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +38 -0
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +34 -0
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +39 -0
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +25 -12
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +29 -14
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +33 -0
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +29 -14
- flwr/cli/run/run.py +168 -17
- flwr/cli/utils.py +75 -4
- flwr/client/__init__.py +6 -1
- flwr/client/app.py +239 -248
- flwr/client/client_app.py +70 -9
- flwr/client/dpfedavg_numpy_client.py +1 -1
- flwr/client/grpc_adapter_client/__init__.py +15 -0
- flwr/client/grpc_adapter_client/connection.py +97 -0
- flwr/client/grpc_client/connection.py +18 -5
- flwr/client/grpc_rere_client/__init__.py +1 -1
- flwr/client/grpc_rere_client/client_interceptor.py +158 -0
- flwr/client/grpc_rere_client/connection.py +127 -33
- flwr/client/grpc_rere_client/grpc_adapter.py +140 -0
- flwr/client/heartbeat.py +74 -0
- flwr/client/message_handler/__init__.py +1 -1
- flwr/client/message_handler/message_handler.py +7 -7
- flwr/client/mod/__init__.py +5 -5
- flwr/client/mod/centraldp_mods.py +4 -2
- flwr/client/mod/comms_mods.py +4 -4
- flwr/client/mod/localdp_mod.py +9 -4
- flwr/client/mod/secure_aggregation/__init__.py +1 -1
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +1 -1
- flwr/client/mod/utils.py +1 -1
- flwr/client/node_state.py +60 -10
- flwr/client/node_state_tests.py +4 -3
- flwr/client/rest_client/__init__.py +1 -1
- flwr/client/rest_client/connection.py +177 -157
- flwr/client/supernode/__init__.py +26 -0
- flwr/client/supernode/app.py +464 -0
- flwr/client/typing.py +1 -0
- flwr/common/__init__.py +13 -11
- flwr/common/address.py +1 -1
- flwr/common/config.py +193 -0
- flwr/common/constant.py +42 -1
- flwr/common/context.py +26 -1
- flwr/common/date.py +1 -1
- flwr/common/dp.py +1 -1
- flwr/common/grpc.py +6 -2
- flwr/common/logger.py +79 -8
- flwr/common/message.py +167 -105
- flwr/common/object_ref.py +126 -25
- flwr/common/record/__init__.py +1 -1
- flwr/common/record/parametersrecord.py +0 -1
- flwr/common/record/recordset.py +78 -27
- flwr/common/recordset_compat.py +8 -1
- flwr/common/retry_invoker.py +25 -13
- flwr/common/secure_aggregation/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/shamir.py +1 -1
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +21 -2
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
- flwr/common/secure_aggregation/quantization.py +1 -1
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
- flwr/common/serde.py +209 -3
- flwr/common/telemetry.py +25 -0
- flwr/common/typing.py +38 -0
- flwr/common/version.py +14 -0
- flwr/proto/clientappio_pb2.py +41 -0
- flwr/proto/clientappio_pb2.pyi +110 -0
- flwr/proto/clientappio_pb2_grpc.py +101 -0
- flwr/proto/clientappio_pb2_grpc.pyi +40 -0
- flwr/proto/common_pb2.py +36 -0
- flwr/proto/common_pb2.pyi +121 -0
- flwr/proto/common_pb2_grpc.py +4 -0
- flwr/proto/common_pb2_grpc.pyi +4 -0
- flwr/proto/driver_pb2.py +26 -19
- flwr/proto/driver_pb2.pyi +34 -0
- flwr/proto/driver_pb2_grpc.py +70 -0
- flwr/proto/driver_pb2_grpc.pyi +28 -0
- flwr/proto/exec_pb2.py +43 -0
- flwr/proto/exec_pb2.pyi +95 -0
- flwr/proto/exec_pb2_grpc.py +101 -0
- flwr/proto/exec_pb2_grpc.pyi +41 -0
- flwr/proto/fab_pb2.py +30 -0
- flwr/proto/fab_pb2.pyi +56 -0
- flwr/proto/fab_pb2_grpc.py +4 -0
- flwr/proto/fab_pb2_grpc.pyi +4 -0
- flwr/proto/fleet_pb2.py +29 -23
- flwr/proto/fleet_pb2.pyi +33 -0
- flwr/proto/fleet_pb2_grpc.py +102 -0
- flwr/proto/fleet_pb2_grpc.pyi +35 -0
- flwr/proto/grpcadapter_pb2.py +32 -0
- flwr/proto/grpcadapter_pb2.pyi +43 -0
- flwr/proto/grpcadapter_pb2_grpc.py +66 -0
- flwr/proto/grpcadapter_pb2_grpc.pyi +24 -0
- flwr/proto/message_pb2.py +41 -0
- flwr/proto/message_pb2.pyi +122 -0
- flwr/proto/message_pb2_grpc.py +4 -0
- flwr/proto/message_pb2_grpc.pyi +4 -0
- flwr/proto/run_pb2.py +35 -0
- flwr/proto/run_pb2.pyi +76 -0
- flwr/proto/run_pb2_grpc.py +4 -0
- flwr/proto/run_pb2_grpc.pyi +4 -0
- flwr/proto/task_pb2.py +7 -8
- flwr/proto/task_pb2.pyi +8 -5
- flwr/server/__init__.py +4 -8
- flwr/server/app.py +298 -350
- flwr/server/compat/app.py +6 -57
- flwr/server/compat/app_utils.py +5 -4
- flwr/server/compat/driver_client_proxy.py +29 -48
- flwr/server/compat/legacy_context.py +5 -4
- flwr/server/driver/__init__.py +2 -0
- flwr/server/driver/driver.py +22 -132
- flwr/server/driver/grpc_driver.py +224 -74
- flwr/server/driver/inmemory_driver.py +183 -0
- flwr/server/history.py +20 -20
- flwr/server/run_serverapp.py +121 -34
- flwr/server/server.py +11 -7
- flwr/server/server_app.py +59 -10
- flwr/server/serverapp_components.py +52 -0
- flwr/server/strategy/__init__.py +2 -2
- flwr/server/strategy/bulyan.py +1 -1
- flwr/server/strategy/dp_adaptive_clipping.py +3 -3
- flwr/server/strategy/dp_fixed_clipping.py +4 -3
- flwr/server/strategy/dpfedavg_adaptive.py +1 -1
- flwr/server/strategy/dpfedavg_fixed.py +1 -1
- flwr/server/strategy/fedadagrad.py +1 -1
- flwr/server/strategy/fedadam.py +1 -1
- flwr/server/strategy/fedavg_android.py +1 -1
- flwr/server/strategy/fedavgm.py +1 -1
- flwr/server/strategy/fedmedian.py +1 -1
- flwr/server/strategy/fedopt.py +1 -1
- flwr/server/strategy/fedprox.py +1 -1
- flwr/server/strategy/fedxgb_bagging.py +1 -1
- flwr/server/strategy/fedxgb_cyclic.py +1 -1
- flwr/server/strategy/fedxgb_nn_avg.py +1 -1
- flwr/server/strategy/fedyogi.py +1 -1
- flwr/server/strategy/krum.py +1 -1
- flwr/server/strategy/qfedavg.py +1 -1
- flwr/server/superlink/driver/__init__.py +1 -1
- flwr/server/superlink/driver/driver_grpc.py +1 -1
- flwr/server/superlink/driver/driver_servicer.py +51 -4
- flwr/server/superlink/ffs/__init__.py +24 -0
- flwr/server/superlink/ffs/disk_ffs.py +104 -0
- flwr/server/superlink/ffs/ffs.py +79 -0
- flwr/server/superlink/fleet/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_adapter/__init__.py +15 -0
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +131 -0
- flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +8 -2
- flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +30 -2
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +214 -0
- flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
- flwr/server/superlink/fleet/message_handler/message_handler.py +42 -2
- flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/rest_rere/rest_api.py +59 -1
- flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
- flwr/server/superlink/fleet/vce/backend/backend.py +5 -5
- flwr/server/superlink/fleet/vce/backend/raybackend.py +53 -56
- flwr/server/superlink/fleet/vce/vce_api.py +190 -127
- flwr/server/superlink/state/__init__.py +1 -1
- flwr/server/superlink/state/in_memory_state.py +159 -42
- flwr/server/superlink/state/sqlite_state.py +243 -39
- flwr/server/superlink/state/state.py +81 -6
- flwr/server/superlink/state/state_factory.py +11 -2
- flwr/server/superlink/state/utils.py +62 -0
- flwr/server/typing.py +2 -0
- flwr/server/utils/__init__.py +1 -1
- flwr/server/utils/tensorboard.py +1 -1
- flwr/server/utils/validator.py +23 -9
- flwr/server/workflow/default_workflows.py +67 -25
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -6
- flwr/simulation/__init__.py +7 -4
- flwr/simulation/app.py +67 -36
- flwr/simulation/ray_transport/__init__.py +1 -1
- flwr/simulation/ray_transport/ray_actor.py +20 -46
- flwr/simulation/ray_transport/ray_client_proxy.py +36 -16
- flwr/simulation/run_simulation.py +308 -92
- flwr/superexec/__init__.py +21 -0
- flwr/superexec/app.py +184 -0
- flwr/superexec/deployment.py +185 -0
- flwr/superexec/exec_grpc.py +55 -0
- flwr/superexec/exec_servicer.py +70 -0
- flwr/superexec/executor.py +75 -0
- flwr/superexec/simulation.py +193 -0
- {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/METADATA +10 -6
- flwr_nightly-1.11.0.dev20240813.dist-info/RECORD +288 -0
- flwr_nightly-1.11.0.dev20240813.dist-info/entry_points.txt +10 -0
- flwr/cli/flower_toml.py +0 -140
- flwr/cli/new/templates/app/flower.toml.tpl +0 -13
- flwr/cli/new/templates/app/requirements.numpy.txt.tpl +0 -2
- flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +0 -4
- flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl +0 -4
- flwr_nightly-1.8.0.dev20240314.dist-info/RECORD +0 -211
- flwr_nightly-1.8.0.dev20240314.dist-info/entry_points.txt +0 -9
- {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,214 @@
|
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Flower server interceptor."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
import base64
|
|
19
|
+
from logging import WARNING
|
|
20
|
+
from typing import Any, Callable, Optional, Sequence, Tuple, Union
|
|
21
|
+
|
|
22
|
+
import grpc
|
|
23
|
+
from cryptography.hazmat.primitives.asymmetric import ec
|
|
24
|
+
|
|
25
|
+
from flwr.common.logger import log
|
|
26
|
+
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
|
27
|
+
bytes_to_private_key,
|
|
28
|
+
bytes_to_public_key,
|
|
29
|
+
generate_shared_key,
|
|
30
|
+
verify_hmac,
|
|
31
|
+
)
|
|
32
|
+
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
33
|
+
CreateNodeRequest,
|
|
34
|
+
CreateNodeResponse,
|
|
35
|
+
DeleteNodeRequest,
|
|
36
|
+
DeleteNodeResponse,
|
|
37
|
+
PingRequest,
|
|
38
|
+
PingResponse,
|
|
39
|
+
PullTaskInsRequest,
|
|
40
|
+
PullTaskInsResponse,
|
|
41
|
+
PushTaskResRequest,
|
|
42
|
+
PushTaskResResponse,
|
|
43
|
+
)
|
|
44
|
+
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
45
|
+
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
46
|
+
from flwr.server.superlink.state import State
|
|
47
|
+
|
|
48
|
+
# Metadata keys used by the authentication interceptor in this module:
# - "public-key": URL-safe base64 public key (read from the client's invocation
#   metadata, and sent back with the server's key on node creation).
# - "auth-token": URL-safe base64 HMAC computed over the serialized request.
_PUBLIC_KEY_HEADER = "public-key"
_AUTH_TOKEN_HEADER = "auth-token"

# Request message types the interceptor's generic handler is annotated to accept.
Request = Union[
    CreateNodeRequest,
    DeleteNodeRequest,
    PullTaskInsRequest,
    PushTaskResRequest,
    GetRunRequest,
    PingRequest,
]

# Response message types matching the requests above.
Response = Union[
    CreateNodeResponse,
    DeleteNodeResponse,
    PullTaskInsResponse,
    PushTaskResResponse,
    GetRunResponse,
    PingResponse,
]
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def _get_value_from_tuples(
|
|
71
|
+
key_string: str, tuples: Sequence[Tuple[str, Union[str, bytes]]]
|
|
72
|
+
) -> bytes:
|
|
73
|
+
value = next((value for key, value in tuples if key == key_string), "")
|
|
74
|
+
if isinstance(value, str):
|
|
75
|
+
return value.encode()
|
|
76
|
+
|
|
77
|
+
return value
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class AuthenticateServerInterceptor(grpc.ServerInterceptor):  # type: ignore
    """Server interceptor for client authentication.

    Every intercepted unary call must carry a known client public key in its
    invocation metadata. A `CreateNodeRequest` is answered by the interceptor
    itself; every other request must additionally carry a valid HMAC of the
    serialized request and must refer to the node id stored for that public key.
    """

    def __init__(self, state: State):
        """Load the known client keys and the server key pair from `state`.

        Raises
        ------
        ValueError
            If the server private or public key cannot be obtained from `state`.
        """
        self.state = state

        self.client_public_keys = state.get_client_public_keys()
        if not self.client_public_keys:
            log(WARNING, "Authentication enabled, but no known public keys configured")

        private_key = self.state.get_server_private_key()
        public_key = self.state.get_server_public_key()

        if private_key is None or public_key is None:
            raise ValueError("Error loading authentication keys")

        self.server_private_key = bytes_to_private_key(private_key)
        self.encoded_server_public_key = base64.urlsafe_b64encode(public_key)

    def intercept_service(
        self,
        continuation: Callable[[Any], Any],
        handler_call_details: grpc.HandlerCallDetails,
    ) -> Optional[grpc.RpcMethodHandler]:
        """Flower server interceptor authentication logic.

        Intercept all unary calls from clients and authenticate clients by validating
        auth metadata sent by the client. Continue RPC call if client is authenticated,
        else, terminate RPC call by setting context to abort.

        Returns `None` when no handler is registered for the requested method, so
        that gRPC itself can reject the call.
        """
        # One of the method handlers in
        # `flwr.server.superlink.fleet.grpc_rere.fleet_server.FleetServicer`
        method_handler: Optional[grpc.RpcMethodHandler] = continuation(
            handler_call_details
        )
        # `continuation` yields None for an unknown method; wrapping None would
        # fail with AttributeError instead of letting gRPC report the error.
        if method_handler is None:
            return None
        return self._generic_auth_unary_method_handler(method_handler)

    @staticmethod
    def _decode_metadata_value(
        header: str,
        metadata: Sequence[Tuple[str, Union[str, bytes]]],
        context: grpc.ServicerContext,
    ) -> bytes:
        """Return the URL-safe base64-decoded metadata value stored under `header`.

        Metadata is client-controlled, so a value that is not valid base64 aborts
        the call as unauthenticated instead of surfacing as an unhandled exception.
        """
        try:
            return base64.urlsafe_b64decode(_get_value_from_tuples(header, metadata))
        except ValueError:  # `binascii.Error` is a subclass of `ValueError`
            context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
            # Only reached if `abort` returns; never continue with an undecoded value
            raise

    def _generic_auth_unary_method_handler(
        self, method_handler: grpc.RpcMethodHandler
    ) -> grpc.RpcMethodHandler:
        """Wrap `method_handler` in a unary-unary handler that authenticates first."""

        def _generic_method_handler(
            request: Request,
            context: grpc.ServicerContext,
        ) -> Response:
            metadata = context.invocation_metadata()

            # The client must present one of the known public keys
            client_public_key_bytes = self._decode_metadata_value(
                _PUBLIC_KEY_HEADER, metadata, context
            )
            if client_public_key_bytes not in self.client_public_keys:
                context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")

            # Node creation is handled here and is not subject to the HMAC check
            if isinstance(request, CreateNodeRequest):
                return self._create_authenticated_node(
                    client_public_key_bytes, request, context
                )

            # Verify hmac value
            hmac_value = self._decode_metadata_value(
                _AUTH_TOKEN_HEADER, metadata, context
            )
            public_key = bytes_to_public_key(client_public_key_bytes)

            if not self._verify_hmac(public_key, request, hmac_value):
                context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")

            # Verify node_id
            node_id = self.state.get_node_id(client_public_key_bytes)

            if not self._verify_node_id(node_id, request):
                context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")

            return method_handler.unary_unary(request, context)  # type: ignore

        return grpc.unary_unary_rpc_method_handler(
            _generic_method_handler,
            request_deserializer=method_handler.request_deserializer,
            response_serializer=method_handler.response_serializer,
        )

    def _verify_node_id(
        self,
        node_id: Optional[int],
        request: Union[
            DeleteNodeRequest,
            PullTaskInsRequest,
            PushTaskResRequest,
            GetRunRequest,
            PingRequest,
        ],
    ) -> bool:
        """Return True if `request` refers to the node id stored for the caller.

        - No stored node id: always False.
        - `PushTaskResRequest`: the producer of the first task result must match;
          an empty result list is rejected.
        - `GetRunRequest`: the node must be among the nodes of the requested run.
        - Any other request: `request.node.node_id` must match.
        """
        if node_id is None:
            return False
        if isinstance(request, PushTaskResRequest):
            if len(request.task_res_list) == 0:
                return False
            return request.task_res_list[0].task.producer.node_id == node_id
        if isinstance(request, GetRunRequest):
            return node_id in self.state.get_nodes(request.run_id)
        return request.node.node_id == node_id

    def _verify_hmac(
        self, public_key: ec.EllipticCurvePublicKey, request: Request, hmac_value: bytes
    ) -> bool:
        """Check `hmac_value` against the HMAC of the serialized request.

        The key is the shared secret derived from the server private key and the
        client's public key; the request is serialized deterministically.
        """
        shared_secret = generate_shared_key(self.server_private_key, public_key)
        return verify_hmac(shared_secret, request.SerializeToString(True), hmac_value)

    def _create_authenticated_node(
        self,
        public_key_bytes: bytes,
        request: CreateNodeRequest,
        context: grpc.ServicerContext,
    ) -> CreateNodeResponse:
        """Answer `CreateNode` for an authenticated client.

        Sends the server's encoded public key as initial metadata, then returns the
        node id already stored for `public_key_bytes` or creates a new node for it.
        """
        context.send_initial_metadata(
            (
                (
                    _PUBLIC_KEY_HEADER,
                    self.encoded_server_public_key,
                ),
            )
        )

        node_id = self.state.get_node_id(public_key_bytes)

        # Handle `CreateNode` here instead of calling the default method handler
        # Return previously assigned `node_id` for the provided `public_key`
        if node_id is not None:
            self.state.acknowledge_ping(node_id, request.ping_interval)
            return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))

        # No `node_id` exists for the provided `public_key`
        # Handle `CreateNode` here instead of calling the default method handler
        # Note: the innermost `CreateNode` method will never be called
        node_id = self.state.create_node(request.ping_interval, public_key_bytes)
        return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -15,14 +15,18 @@
|
|
|
15
15
|
"""Fleet API message handlers."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
import time
|
|
18
19
|
from typing import List, Optional
|
|
19
20
|
from uuid import UUID
|
|
20
21
|
|
|
22
|
+
from flwr.common.serde import user_config_to_proto
|
|
21
23
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
22
24
|
CreateNodeRequest,
|
|
23
25
|
CreateNodeResponse,
|
|
24
26
|
DeleteNodeRequest,
|
|
25
27
|
DeleteNodeResponse,
|
|
28
|
+
PingRequest,
|
|
29
|
+
PingResponse,
|
|
26
30
|
PullTaskInsRequest,
|
|
27
31
|
PullTaskInsResponse,
|
|
28
32
|
PushTaskResRequest,
|
|
@@ -30,6 +34,11 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
30
34
|
Reconnect,
|
|
31
35
|
)
|
|
32
36
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
37
|
+
from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
38
|
+
GetRunRequest,
|
|
39
|
+
GetRunResponse,
|
|
40
|
+
Run,
|
|
41
|
+
)
|
|
33
42
|
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
|
34
43
|
from flwr.server.superlink.state import State
|
|
35
44
|
|
|
@@ -40,7 +49,7 @@ def create_node(
|
|
|
40
49
|
) -> CreateNodeResponse:
|
|
41
50
|
"""."""
|
|
42
51
|
# Create node
|
|
43
|
-
node_id = state.create_node()
|
|
52
|
+
node_id = state.create_node(ping_interval=request.ping_interval)
|
|
44
53
|
return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
|
|
45
54
|
|
|
46
55
|
|
|
@@ -55,6 +64,15 @@ def delete_node(request: DeleteNodeRequest, state: State) -> DeleteNodeResponse:
|
|
|
55
64
|
return DeleteNodeResponse()
|
|
56
65
|
|
|
57
66
|
|
|
67
|
+
def ping(
    request: PingRequest,
    state: State,
) -> PingResponse:
    """Acknowledge a ping from a node and report whether it was accepted.

    The node id and ping interval from `request` are passed to
    `state.acknowledge_ping`; its result becomes the `success` field.
    """
    acknowledged = state.acknowledge_ping(
        request.node.node_id,
        request.ping_interval,
    )
    return PingResponse(success=acknowledged)
|
|
74
|
+
|
|
75
|
+
|
|
58
76
|
def pull_task_ins(request: PullTaskInsRequest, state: State) -> PullTaskInsResponse:
|
|
59
77
|
"""Pull TaskIns handler."""
|
|
60
78
|
# Get node_id if client node is not anonymous
|
|
@@ -77,6 +95,9 @@ def push_task_res(request: PushTaskResRequest, state: State) -> PushTaskResRespo
|
|
|
77
95
|
task_res: TaskRes = request.task_res_list[0]
|
|
78
96
|
# pylint: enable=no-member
|
|
79
97
|
|
|
98
|
+
# Set pushed_at (timestamp in seconds)
|
|
99
|
+
task_res.task.pushed_at = time.time()
|
|
100
|
+
|
|
80
101
|
# Store TaskRes in State
|
|
81
102
|
task_id: Optional[UUID] = state.store_task_res(task_res=task_res)
|
|
82
103
|
|
|
@@ -86,3 +107,22 @@ def push_task_res(request: PushTaskResRequest, state: State) -> PushTaskResRespo
|
|
|
86
107
|
results={str(task_id): 0},
|
|
87
108
|
)
|
|
88
109
|
return response
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def get_run(
    request: GetRunRequest,
    state: State,
) -> GetRunResponse:
    """Get run information.

    Looks up `request.run_id` in `state`. Returns an empty `GetRunResponse` when
    the state has no such run, otherwise a response carrying the run's id, FAB id,
    FAB version and override config.
    """
    stored_run = state.get_run(request.run_id)

    # Unknown run: answer with an empty response
    if stored_run is None:
        return GetRunResponse()

    run_proto = Run(
        run_id=stored_run.run_id,
        fab_id=stored_run.fab_id,
        fab_version=stored_run.fab_version,
        override_config=user_config_to_proto(stored_run.override_config),
    )
    return GetRunResponse(run=run_proto)
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -21,9 +21,11 @@ from flwr.common.constant import MISSING_EXTRA_REST
|
|
|
21
21
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
22
22
|
CreateNodeRequest,
|
|
23
23
|
DeleteNodeRequest,
|
|
24
|
+
PingRequest,
|
|
24
25
|
PullTaskInsRequest,
|
|
25
26
|
PushTaskResRequest,
|
|
26
27
|
)
|
|
28
|
+
from flwr.proto.run_pb2 import GetRunRequest # pylint: disable=E0611
|
|
27
29
|
from flwr.server.superlink.fleet.message_handler import message_handler
|
|
28
30
|
from flwr.server.superlink.state import State
|
|
29
31
|
|
|
@@ -152,11 +154,67 @@ async def push_task_res(request: Request) -> Response: # Check if token is need
|
|
|
152
154
|
)
|
|
153
155
|
|
|
154
156
|
|
|
157
|
+
async def ping(request: Request) -> Response:
    """Handle a ping request received as a serialized `PingRequest`.

    The request headers are checked, the body is parsed as protobuf and handed to
    `message_handler.ping`, and the serialized reply is returned with status 200.
    """
    _check_headers(request.headers)

    # Read the raw body and parse it into a PingRequest message
    payload: bytes = await request.body()
    ping_request = PingRequest()
    ping_request.ParseFromString(payload)

    # Obtain the state from the app's state factory
    state: State = app.state.STATE_FACTORY.state()

    # Delegate to the shared message handler
    ping_response = message_handler.ping(request=ping_request, state=state)

    # Reply with the serialized protobuf message
    return Response(
        status_code=200,
        content=ping_response.SerializeToString(),
        headers={"Content-Type": "application/protobuf"},
    )
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
async def get_run(request: Request) -> Response:
    """Handle a run lookup received as a serialized `GetRunRequest`.

    The request headers are checked, the body is parsed as protobuf and handed to
    `message_handler.get_run`, and the serialized reply is returned with status 200.
    """
    _check_headers(request.headers)

    # Read the raw body and parse it into a GetRunRequest message
    payload: bytes = await request.body()
    run_request = GetRunRequest()
    run_request.ParseFromString(payload)

    # Obtain the state from the app's state factory
    state: State = app.state.STATE_FACTORY.state()

    # Delegate to the shared message handler
    run_response = message_handler.get_run(request=run_request, state=state)

    # Reply with the serialized protobuf message
    return Response(
        status_code=200,
        content=run_response.SerializeToString(),
        headers={"Content-Type": "application/protobuf"},
    )
|
|
209
|
+
|
|
210
|
+
|
|
155
211
|
# Fleet API REST routes; every endpoint accepts POST only and maps to the
# handler function of the same name.
routes = [
    Route("/api/v0/fleet/create-node", create_node, methods=["POST"]),
    Route("/api/v0/fleet/delete-node", delete_node, methods=["POST"]),
    Route("/api/v0/fleet/pull-task-ins", pull_task_ins, methods=["POST"]),
    Route("/api/v0/fleet/push-task-res", push_task_res, methods=["POST"]),
    Route("/api/v0/fleet/ping", ping, methods=["POST"]),
    Route("/api/v0/fleet/get-run", get_run, methods=["POST"]),
]
|
|
161
219
|
|
|
162
220
|
app: Starlette = Starlette(
|
|
@@ -29,12 +29,12 @@ BackendConfig = Dict[str, Dict[str, ConfigsRecordValues]]
|
|
|
29
29
|
class Backend(ABC):
|
|
30
30
|
"""Abstract base class for a Simulation Engine Backend."""
|
|
31
31
|
|
|
32
|
-
def __init__(self, backend_config: BackendConfig
|
|
32
|
+
def __init__(self, backend_config: BackendConfig) -> None:
|
|
33
33
|
"""Construct a backend."""
|
|
34
34
|
|
|
35
35
|
@abstractmethod
|
|
36
|
-
|
|
37
|
-
"""Build backend
|
|
36
|
+
def build(self) -> None:
|
|
37
|
+
"""Build backend.
|
|
38
38
|
|
|
39
39
|
Different components need to be in place before workers in a backend are ready
|
|
40
40
|
to accept jobs. When this method finishes executing, the backend should be fully
|
|
@@ -54,11 +54,11 @@ class Backend(ABC):
|
|
|
54
54
|
"""Report whether a backend worker is idle and can therefore run a ClientApp."""
|
|
55
55
|
|
|
56
56
|
@abstractmethod
|
|
57
|
-
|
|
57
|
+
def terminate(self) -> None:
|
|
58
58
|
"""Terminate backend."""
|
|
59
59
|
|
|
60
60
|
@abstractmethod
|
|
61
|
-
|
|
61
|
+
def process_message(
|
|
62
62
|
self,
|
|
63
63
|
app: Callable[[], ClientApp],
|
|
64
64
|
message: Message,
|
|
@@ -14,26 +14,24 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Ray backend for the Fleet API using the Simulation Engine."""
|
|
16
16
|
|
|
17
|
-
import
|
|
18
|
-
from
|
|
19
|
-
from typing import Callable, Dict, List, Tuple, Union
|
|
17
|
+
from logging import DEBUG, ERROR
|
|
18
|
+
from typing import Callable, Dict, Tuple, Union
|
|
20
19
|
|
|
21
20
|
import ray
|
|
22
21
|
|
|
23
|
-
from flwr.client.client_app import ClientApp
|
|
22
|
+
from flwr.client.client_app import ClientApp
|
|
23
|
+
from flwr.common.constant import PARTITION_ID_KEY
|
|
24
24
|
from flwr.common.context import Context
|
|
25
25
|
from flwr.common.logger import log
|
|
26
26
|
from flwr.common.message import Message
|
|
27
|
-
from flwr.
|
|
28
|
-
|
|
29
|
-
ClientAppActor,
|
|
30
|
-
init_ray,
|
|
31
|
-
)
|
|
27
|
+
from flwr.common.typing import ConfigsRecordValues
|
|
28
|
+
from flwr.simulation.ray_transport.ray_actor import BasicActorPool, ClientAppActor
|
|
32
29
|
from flwr.simulation.ray_transport.utils import enable_tf_gpu_growth
|
|
33
30
|
|
|
34
31
|
from .backend import Backend, BackendConfig
|
|
35
32
|
|
|
36
33
|
ClientResourcesDict = Dict[str, Union[int, float]]
|
|
34
|
+
ActorArgsDict = Dict[str, Union[int, float, Callable[[], None]]]
|
|
37
35
|
|
|
38
36
|
|
|
39
37
|
class RayBackend(Backend):
|
|
@@ -42,52 +40,28 @@ class RayBackend(Backend):
|
|
|
42
40
|
def __init__(
|
|
43
41
|
self,
|
|
44
42
|
backend_config: BackendConfig,
|
|
45
|
-
work_dir: str,
|
|
46
43
|
) -> None:
|
|
47
44
|
"""Prepare RayBackend by initialising Ray and creating the ActorPool."""
|
|
48
|
-
log(
|
|
49
|
-
log(
|
|
45
|
+
log(DEBUG, "Initialising: %s", self.__class__.__name__)
|
|
46
|
+
log(DEBUG, "Backend config: %s", backend_config)
|
|
50
47
|
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
# Init ray and append working dir if needed
|
|
55
|
-
runtime_env = (
|
|
56
|
-
self._configure_runtime_env(work_dir=work_dir) if work_dir else None
|
|
57
|
-
)
|
|
58
|
-
init_ray(runtime_env=runtime_env)
|
|
48
|
+
# Initialise ray
|
|
49
|
+
self.init_args_key = "init_args"
|
|
50
|
+
self.init_ray(backend_config)
|
|
59
51
|
|
|
60
52
|
# Validate client resources
|
|
61
53
|
self.client_resources_key = "client_resources"
|
|
54
|
+
client_resources = self._validate_client_resources(config=backend_config)
|
|
62
55
|
|
|
63
56
|
# Create actor pool
|
|
64
|
-
|
|
65
|
-
actor_kwargs = {"on_actor_init_fn": enable_tf_gpu_growth} if use_tf else {}
|
|
57
|
+
actor_kwargs = self._validate_actor_arguments(config=backend_config)
|
|
66
58
|
|
|
67
|
-
client_resources = self._validate_client_resources(config=backend_config)
|
|
68
59
|
self.pool = BasicActorPool(
|
|
69
60
|
actor_type=ClientAppActor,
|
|
70
61
|
client_resources=client_resources,
|
|
71
62
|
actor_kwargs=actor_kwargs,
|
|
72
63
|
)
|
|
73
64
|
|
|
74
|
-
def _configure_runtime_env(self, work_dir: str) -> Dict[str, Union[str, List[str]]]:
|
|
75
|
-
"""Return list of files/subdirectories to exclude relative to work_dir.
|
|
76
|
-
|
|
77
|
-
Without this, Ray will push everything to the Ray Cluster.
|
|
78
|
-
"""
|
|
79
|
-
runtime_env: Dict[str, Union[str, List[str]]] = {"working_dir": work_dir}
|
|
80
|
-
|
|
81
|
-
excludes = []
|
|
82
|
-
path = pathlib.Path(work_dir)
|
|
83
|
-
for p in path.rglob("*"):
|
|
84
|
-
# Exclude files need to be relative to the working_dir
|
|
85
|
-
if p.is_file() and not str(p).endswith(".py"):
|
|
86
|
-
excludes.append(str(p.relative_to(path)))
|
|
87
|
-
runtime_env["excludes"] = excludes
|
|
88
|
-
|
|
89
|
-
return runtime_env
|
|
90
|
-
|
|
91
65
|
def _validate_client_resources(self, config: BackendConfig) -> ClientResourcesDict:
|
|
92
66
|
client_resources_config = config.get(self.client_resources_key)
|
|
93
67
|
client_resources: ClientResourcesDict = {}
|
|
@@ -109,7 +83,7 @@ class RayBackend(Backend):
|
|
|
109
83
|
else:
|
|
110
84
|
client_resources = {"num_cpus": 2, "num_gpus": 0.0}
|
|
111
85
|
log(
|
|
112
|
-
|
|
86
|
+
DEBUG,
|
|
113
87
|
"`%s` not specified in backend config. Applying default setting: %s",
|
|
114
88
|
self.client_resources_key,
|
|
115
89
|
client_resources,
|
|
@@ -117,6 +91,29 @@ class RayBackend(Backend):
|
|
|
117
91
|
|
|
118
92
|
return client_resources
|
|
119
93
|
|
|
94
|
+
def _validate_actor_arguments(self, config: BackendConfig) -> ActorArgsDict:
    """Build the keyword arguments handed to the actors of the pool.

    Parameters
    ----------
    config : BackendConfig
        Backend configuration. Only its optional ``"actor"`` section is
        read here.

    Returns
    -------
    ActorArgsDict
        ``{"on_actor_init_fn": enable_tf_gpu_growth}`` when the ``"actor"``
        section sets a truthy ``"tensorflow"`` entry, otherwise an empty
        dict.
    """
    actor_args_config = config.get("actor", False)
    actor_args: ActorArgsDict = {}
    if actor_args_config:
        # The flag lives in the user-supplied "actor" section. Reading it
        # from `actor_args` (the still-empty result dict) would always
        # yield False and silently ignore the setting.
        use_tf = actor_args_config.get("tensorflow", False)
        if use_tf:
            actor_args["on_actor_init_fn"] = enable_tf_gpu_growth
    return actor_args
|
|
102
|
+
|
|
103
|
+
def init_ray(self, backend_config: BackendConfig) -> None:
    """Initialise Ray unless it has already been initialised.

    Entries stored in ``backend_config`` under ``self.init_args_key`` are
    forwarded as keyword arguments to ``ray.init``. When that section is
    missing or empty, ``ray.init`` is called without arguments.
    """
    # Nothing to do when Ray is already up
    if ray.is_initialized():
        return

    ray_init_args: Dict[str, ConfigsRecordValues] = {}

    user_init_args = backend_config.get(self.init_args_key)
    if user_init_args:
        # Copy the entries so the backend config itself is never mutated
        ray_init_args.update(user_init_args.items())

    ray.init(**ray_init_args)
|
|
116
|
+
|
|
120
117
|
@property
|
|
121
118
|
def num_workers(self) -> int:
|
|
122
119
|
"""Return number of actors in pool."""
|
|
@@ -126,12 +123,12 @@ class RayBackend(Backend):
|
|
|
126
123
|
"""Report whether the pool has idle actors."""
|
|
127
124
|
return self.pool.is_actor_available()
|
|
128
125
|
|
|
129
|
-
|
|
126
|
+
def build(self) -> None:
|
|
130
127
|
"""Build pool of Ray actors that this backend will submit jobs to."""
|
|
131
|
-
|
|
132
|
-
log(
|
|
128
|
+
self.pool.add_actors_to_pool(self.pool.actors_capacity)
|
|
129
|
+
log(DEBUG, "Constructed ActorPool with: %i actors", self.pool.num_actors)
|
|
133
130
|
|
|
134
|
-
|
|
131
|
+
def process_message(
|
|
135
132
|
self,
|
|
136
133
|
app: Callable[[], ClientApp],
|
|
137
134
|
message: Message,
|
|
@@ -141,35 +138,35 @@ class RayBackend(Backend):
|
|
|
141
138
|
|
|
142
139
|
Return output message and updated context.
|
|
143
140
|
"""
|
|
144
|
-
partition_id =
|
|
141
|
+
partition_id = context.node_config[PARTITION_ID_KEY]
|
|
145
142
|
|
|
146
143
|
try:
|
|
147
|
-
#
|
|
148
|
-
future =
|
|
144
|
+
# Submit a task to the pool
|
|
145
|
+
future = self.pool.submit(
|
|
149
146
|
lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state),
|
|
150
147
|
(app, message, str(partition_id), context),
|
|
151
148
|
)
|
|
152
149
|
|
|
153
|
-
await future
|
|
154
|
-
|
|
155
150
|
# Fetch result
|
|
156
151
|
(
|
|
157
152
|
out_mssg,
|
|
158
153
|
updated_context,
|
|
159
|
-
) =
|
|
154
|
+
) = self.pool.fetch_result_and_return_actor_to_pool(future)
|
|
160
155
|
|
|
161
156
|
return out_mssg, updated_context
|
|
162
157
|
|
|
163
|
-
except
|
|
158
|
+
except Exception as ex:
|
|
164
159
|
log(
|
|
165
160
|
ERROR,
|
|
166
161
|
"An exception was raised when processing a message by %s",
|
|
167
162
|
self.__class__.__name__,
|
|
168
163
|
)
|
|
169
|
-
|
|
164
|
+
# add actor back into pool
|
|
165
|
+
self.pool.add_actor_back_to_pool(future)
|
|
166
|
+
raise ex
|
|
170
167
|
|
|
171
|
-
|
|
168
|
+
def terminate(self) -> None:
|
|
172
169
|
"""Terminate all actors in actor pool."""
|
|
173
|
-
|
|
170
|
+
self.pool.terminate_all_actors()
|
|
174
171
|
ray.shutdown()
|
|
175
|
-
log(
|
|
172
|
+
log(DEBUG, "Terminated %s", self.__class__.__name__)
|