flwr-nightly 1.9.0.dev20240420__py3-none-any.whl → 1.9.0.dev20240509__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/app.py +2 -0
- flwr/cli/build.py +151 -0
- flwr/cli/config_utils.py +18 -46
- flwr/cli/new/new.py +44 -18
- flwr/cli/new/templates/app/code/client.hf.py.tpl +55 -0
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +70 -0
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +1 -1
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +94 -0
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +15 -29
- flwr/cli/new/templates/app/code/server.hf.py.tpl +17 -0
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +15 -0
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +17 -0
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +9 -1
- flwr/cli/new/templates/app/code/task.hf.py.tpl +87 -0
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +89 -0
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +29 -0
- flwr/cli/new/templates/app/pyproject.hf.toml.tpl +31 -0
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +28 -0
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +7 -4
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +7 -4
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +27 -0
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +7 -4
- flwr/cli/run/run.py +1 -1
- flwr/cli/utils.py +18 -17
- flwr/client/__init__.py +1 -1
- flwr/client/app.py +17 -93
- flwr/client/grpc_client/connection.py +6 -1
- flwr/client/grpc_rere_client/client_interceptor.py +158 -0
- flwr/client/grpc_rere_client/connection.py +17 -2
- flwr/client/mod/centraldp_mods.py +4 -2
- flwr/client/mod/localdp_mod.py +9 -3
- flwr/client/rest_client/connection.py +5 -1
- flwr/client/supernode/__init__.py +2 -0
- flwr/client/supernode/app.py +181 -7
- flwr/common/grpc.py +5 -1
- flwr/common/logger.py +37 -4
- flwr/common/message.py +105 -86
- flwr/common/record/parametersrecord.py +0 -1
- flwr/common/record/recordset.py +17 -5
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +35 -1
- flwr/server/__init__.py +0 -2
- flwr/server/app.py +118 -2
- flwr/server/compat/app.py +5 -56
- flwr/server/compat/app_utils.py +1 -1
- flwr/server/compat/driver_client_proxy.py +27 -72
- flwr/server/driver/__init__.py +3 -0
- flwr/server/driver/driver.py +12 -242
- flwr/server/driver/grpc_driver.py +315 -0
- flwr/server/history.py +20 -20
- flwr/server/run_serverapp.py +18 -4
- flwr/server/server.py +2 -5
- flwr/server/strategy/dp_adaptive_clipping.py +5 -3
- flwr/server/strategy/dp_fixed_clipping.py +6 -3
- flwr/server/superlink/driver/driver_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +3 -1
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +215 -0
- flwr/server/superlink/fleet/vce/backend/raybackend.py +9 -6
- flwr/server/superlink/fleet/vce/vce_api.py +1 -1
- flwr/server/superlink/state/in_memory_state.py +76 -8
- flwr/server/superlink/state/sqlite_state.py +116 -11
- flwr/server/superlink/state/state.py +35 -3
- flwr/simulation/__init__.py +2 -2
- flwr/simulation/app.py +16 -1
- flwr/simulation/run_simulation.py +14 -9
- {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240509.dist-info}/METADATA +3 -2
- {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240509.dist-info}/RECORD +70 -55
- {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240509.dist-info}/entry_points.txt +1 -1
- flwr/server/driver/abc_driver.py +0 -140
- {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240509.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240509.dist-info}/WHEEL +0 -0
flwr/server/run_serverapp.py
CHANGED
|
@@ -25,7 +25,7 @@ from flwr.common import Context, EventType, RecordSet, event
|
|
|
25
25
|
from flwr.common.logger import log, update_console_handler
|
|
26
26
|
from flwr.common.object_ref import load_app
|
|
27
27
|
|
|
28
|
-
from .driver
|
|
28
|
+
from .driver import Driver, GrpcDriver
|
|
29
29
|
from .server_app import LoadServerAppError, ServerApp
|
|
30
30
|
|
|
31
31
|
|
|
@@ -128,13 +128,15 @@ def run_server_app() -> None:
|
|
|
128
128
|
server_app_dir = args.dir
|
|
129
129
|
server_app_attr = getattr(args, "server-app")
|
|
130
130
|
|
|
131
|
-
# Initialize
|
|
132
|
-
driver =
|
|
131
|
+
# Initialize GrpcDriver
|
|
132
|
+
driver = GrpcDriver(
|
|
133
133
|
driver_service_address=args.server,
|
|
134
134
|
root_certificates=root_certificates,
|
|
135
|
+
fab_id=args.fab_id,
|
|
136
|
+
fab_version=args.fab_version,
|
|
135
137
|
)
|
|
136
138
|
|
|
137
|
-
# Run the
|
|
139
|
+
# Run the ServerApp with the Driver
|
|
138
140
|
run(driver=driver, server_app_dir=server_app_dir, server_app_attr=server_app_attr)
|
|
139
141
|
|
|
140
142
|
# Clean up
|
|
@@ -183,5 +185,17 @@ def _parse_args_run_server_app() -> argparse.ArgumentParser:
|
|
|
183
185
|
"app from there."
|
|
184
186
|
" Default: current working directory.",
|
|
185
187
|
)
|
|
188
|
+
parser.add_argument(
|
|
189
|
+
"--fab-id",
|
|
190
|
+
default=None,
|
|
191
|
+
type=str,
|
|
192
|
+
help="The identifier of the FAB used in the run.",
|
|
193
|
+
)
|
|
194
|
+
parser.add_argument(
|
|
195
|
+
"--fab-version",
|
|
196
|
+
default=None,
|
|
197
|
+
type=str,
|
|
198
|
+
help="The version of the FAB used in the run.",
|
|
199
|
+
)
|
|
186
200
|
|
|
187
201
|
return parser
|
flwr/server/server.py
CHANGED
|
@@ -487,11 +487,8 @@ def run_fl(
|
|
|
487
487
|
log(INFO, "")
|
|
488
488
|
log(INFO, "[SUMMARY]")
|
|
489
489
|
log(INFO, "Run finished %s rounds in %.2fs", config.num_rounds, elapsed_time)
|
|
490
|
-
for
|
|
491
|
-
|
|
492
|
-
log(INFO, "%s", line.strip("\n"))
|
|
493
|
-
else:
|
|
494
|
-
log(INFO, "\t%s", line.strip("\n"))
|
|
490
|
+
for line in io.StringIO(str(hist)):
|
|
491
|
+
log(INFO, "\t%s", line.strip("\n"))
|
|
495
492
|
log(INFO, "")
|
|
496
493
|
|
|
497
494
|
# Graceful shutdown
|
|
@@ -200,7 +200,7 @@ class DifferentialPrivacyServerSideAdaptiveClipping(Strategy):
|
|
|
200
200
|
|
|
201
201
|
log(
|
|
202
202
|
INFO,
|
|
203
|
-
"aggregate_fit: parameters are clipped by value:
|
|
203
|
+
"aggregate_fit: parameters are clipped by value: %.4f.",
|
|
204
204
|
self.clipping_norm,
|
|
205
205
|
)
|
|
206
206
|
|
|
@@ -234,7 +234,8 @@ class DifferentialPrivacyServerSideAdaptiveClipping(Strategy):
|
|
|
234
234
|
)
|
|
235
235
|
log(
|
|
236
236
|
INFO,
|
|
237
|
-
"aggregate_fit: central DP noise with
|
|
237
|
+
"aggregate_fit: central DP noise with "
|
|
238
|
+
"standard deviation: %.4f added to parameters.",
|
|
238
239
|
compute_stdv(
|
|
239
240
|
self.noise_multiplier, self.clipping_norm, self.num_sampled_clients
|
|
240
241
|
),
|
|
@@ -424,7 +425,8 @@ class DifferentialPrivacyClientSideAdaptiveClipping(Strategy):
|
|
|
424
425
|
)
|
|
425
426
|
log(
|
|
426
427
|
INFO,
|
|
427
|
-
"aggregate_fit: central DP noise with
|
|
428
|
+
"aggregate_fit: central DP noise with "
|
|
429
|
+
"standard deviation: %.4f added to parameters.",
|
|
428
430
|
compute_stdv(
|
|
429
431
|
self.noise_multiplier, self.clipping_norm, self.num_sampled_clients
|
|
430
432
|
),
|
|
@@ -158,7 +158,7 @@ class DifferentialPrivacyServerSideFixedClipping(Strategy):
|
|
|
158
158
|
)
|
|
159
159
|
log(
|
|
160
160
|
INFO,
|
|
161
|
-
"aggregate_fit: parameters are clipped by value:
|
|
161
|
+
"aggregate_fit: parameters are clipped by value: %.4f.",
|
|
162
162
|
self.clipping_norm,
|
|
163
163
|
)
|
|
164
164
|
# Convert back to parameters
|
|
@@ -180,7 +180,8 @@ class DifferentialPrivacyServerSideFixedClipping(Strategy):
|
|
|
180
180
|
|
|
181
181
|
log(
|
|
182
182
|
INFO,
|
|
183
|
-
"aggregate_fit: central DP noise with
|
|
183
|
+
"aggregate_fit: central DP noise with "
|
|
184
|
+
"standard deviation: %.4f added to parameters.",
|
|
184
185
|
compute_stdv(
|
|
185
186
|
self.noise_multiplier, self.clipping_norm, self.num_sampled_clients
|
|
186
187
|
),
|
|
@@ -337,11 +338,13 @@ class DifferentialPrivacyClientSideFixedClipping(Strategy):
|
|
|
337
338
|
)
|
|
338
339
|
log(
|
|
339
340
|
INFO,
|
|
340
|
-
"aggregate_fit: central DP noise with
|
|
341
|
+
"aggregate_fit: central DP noise with "
|
|
342
|
+
"standard deviation: %.4f added to parameters.",
|
|
341
343
|
compute_stdv(
|
|
342
344
|
self.noise_multiplier, self.clipping_norm, self.num_sampled_clients
|
|
343
345
|
),
|
|
344
346
|
)
|
|
347
|
+
|
|
345
348
|
return aggregated_params, metrics
|
|
346
349
|
|
|
347
350
|
def aggregate_evaluate(
|
|
@@ -64,7 +64,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
64
64
|
"""Create run ID."""
|
|
65
65
|
log(INFO, "DriverServicer.CreateRun")
|
|
66
66
|
state: State = self.state_factory.state()
|
|
67
|
-
run_id = state.create_run(
|
|
67
|
+
run_id = state.create_run(request.fab_id, request.fab_version)
|
|
68
68
|
return CreateRunResponse(run_id=run_id)
|
|
69
69
|
|
|
70
70
|
def PushTaskIns(
|
|
@@ -18,7 +18,7 @@
|
|
|
18
18
|
import concurrent.futures
|
|
19
19
|
import sys
|
|
20
20
|
from logging import ERROR
|
|
21
|
-
from typing import Any, Callable, Optional, Tuple, Union
|
|
21
|
+
from typing import Any, Callable, Optional, Sequence, Tuple, Union
|
|
22
22
|
|
|
23
23
|
import grpc
|
|
24
24
|
|
|
@@ -162,6 +162,7 @@ def generic_create_grpc_server( # pylint: disable=too-many-arguments
|
|
|
162
162
|
max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
|
|
163
163
|
keepalive_time_ms: int = 210000,
|
|
164
164
|
certificates: Optional[Tuple[bytes, bytes, bytes]] = None,
|
|
165
|
+
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None,
|
|
165
166
|
) -> grpc.Server:
|
|
166
167
|
"""Create a gRPC server with a single servicer.
|
|
167
168
|
|
|
@@ -249,6 +250,7 @@ def generic_create_grpc_server( # pylint: disable=too-many-arguments
|
|
|
249
250
|
# returning RESOURCE_EXHAUSTED status, or None to indicate no limit.
|
|
250
251
|
maximum_concurrent_rpcs=max_concurrent_workers,
|
|
251
252
|
options=options,
|
|
253
|
+
interceptors=interceptors,
|
|
252
254
|
)
|
|
253
255
|
add_servicer_to_server_fn(servicer, server)
|
|
254
256
|
|
|
@@ -0,0 +1,215 @@
|
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Flower server interceptor."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
import base64
|
|
19
|
+
from logging import WARNING
|
|
20
|
+
from typing import Any, Callable, Optional, Sequence, Tuple, Union
|
|
21
|
+
|
|
22
|
+
import grpc
|
|
23
|
+
from cryptography.hazmat.primitives.asymmetric import ec
|
|
24
|
+
|
|
25
|
+
from flwr.common.logger import log
|
|
26
|
+
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
|
27
|
+
bytes_to_private_key,
|
|
28
|
+
bytes_to_public_key,
|
|
29
|
+
generate_shared_key,
|
|
30
|
+
verify_hmac,
|
|
31
|
+
)
|
|
32
|
+
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
33
|
+
CreateNodeRequest,
|
|
34
|
+
CreateNodeResponse,
|
|
35
|
+
DeleteNodeRequest,
|
|
36
|
+
DeleteNodeResponse,
|
|
37
|
+
GetRunRequest,
|
|
38
|
+
GetRunResponse,
|
|
39
|
+
PingRequest,
|
|
40
|
+
PingResponse,
|
|
41
|
+
PullTaskInsRequest,
|
|
42
|
+
PullTaskInsResponse,
|
|
43
|
+
PushTaskResRequest,
|
|
44
|
+
PushTaskResResponse,
|
|
45
|
+
)
|
|
46
|
+
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
47
|
+
from flwr.server.superlink.state import State
|
|
48
|
+
|
|
49
|
+
_PUBLIC_KEY_HEADER = "public-key"
|
|
50
|
+
_AUTH_TOKEN_HEADER = "auth-token"
|
|
51
|
+
|
|
52
|
+
Request = Union[
|
|
53
|
+
CreateNodeRequest,
|
|
54
|
+
DeleteNodeRequest,
|
|
55
|
+
PullTaskInsRequest,
|
|
56
|
+
PushTaskResRequest,
|
|
57
|
+
GetRunRequest,
|
|
58
|
+
PingRequest,
|
|
59
|
+
]
|
|
60
|
+
|
|
61
|
+
Response = Union[
|
|
62
|
+
CreateNodeResponse,
|
|
63
|
+
DeleteNodeResponse,
|
|
64
|
+
PullTaskInsResponse,
|
|
65
|
+
PushTaskResResponse,
|
|
66
|
+
GetRunResponse,
|
|
67
|
+
PingResponse,
|
|
68
|
+
]
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _get_value_from_tuples(
|
|
72
|
+
key_string: str, tuples: Sequence[Tuple[str, Union[str, bytes]]]
|
|
73
|
+
) -> bytes:
|
|
74
|
+
value = next((value for key, value in tuples if key == key_string), "")
|
|
75
|
+
if isinstance(value, str):
|
|
76
|
+
return value.encode()
|
|
77
|
+
|
|
78
|
+
return value
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
82
|
+
"""Server interceptor for client authentication."""
|
|
83
|
+
|
|
84
|
+
def __init__(self, state: State):
|
|
85
|
+
self.state = state
|
|
86
|
+
|
|
87
|
+
self.client_public_keys = state.get_client_public_keys()
|
|
88
|
+
if len(self.client_public_keys) == 0:
|
|
89
|
+
log(WARNING, "Authentication enabled, but no known public keys configured")
|
|
90
|
+
|
|
91
|
+
private_key = self.state.get_server_private_key()
|
|
92
|
+
public_key = self.state.get_server_public_key()
|
|
93
|
+
|
|
94
|
+
if private_key is None or public_key is None:
|
|
95
|
+
raise ValueError("Error loading authentication keys")
|
|
96
|
+
|
|
97
|
+
self.server_private_key = bytes_to_private_key(private_key)
|
|
98
|
+
self.encoded_server_public_key = base64.urlsafe_b64encode(public_key)
|
|
99
|
+
|
|
100
|
+
def intercept_service(
|
|
101
|
+
self,
|
|
102
|
+
continuation: Callable[[Any], Any],
|
|
103
|
+
handler_call_details: grpc.HandlerCallDetails,
|
|
104
|
+
) -> grpc.RpcMethodHandler:
|
|
105
|
+
"""Flower server interceptor authentication logic.
|
|
106
|
+
|
|
107
|
+
Intercept all unary calls from clients and authenticate clients by validating
|
|
108
|
+
auth metadata sent by the client. Continue RPC call if client is authenticated,
|
|
109
|
+
else, terminate RPC call by setting context to abort.
|
|
110
|
+
"""
|
|
111
|
+
# One of the method handlers in
|
|
112
|
+
# `flwr.server.superlink.fleet.grpc_rere.fleet_server.FleetServicer`
|
|
113
|
+
method_handler: grpc.RpcMethodHandler = continuation(handler_call_details)
|
|
114
|
+
return self._generic_auth_unary_method_handler(method_handler)
|
|
115
|
+
|
|
116
|
+
def _generic_auth_unary_method_handler(
|
|
117
|
+
self, method_handler: grpc.RpcMethodHandler
|
|
118
|
+
) -> grpc.RpcMethodHandler:
|
|
119
|
+
def _generic_method_handler(
|
|
120
|
+
request: Request,
|
|
121
|
+
context: grpc.ServicerContext,
|
|
122
|
+
) -> Response:
|
|
123
|
+
client_public_key_bytes = base64.urlsafe_b64decode(
|
|
124
|
+
_get_value_from_tuples(
|
|
125
|
+
_PUBLIC_KEY_HEADER, context.invocation_metadata()
|
|
126
|
+
)
|
|
127
|
+
)
|
|
128
|
+
if client_public_key_bytes not in self.client_public_keys:
|
|
129
|
+
context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
|
|
130
|
+
|
|
131
|
+
if isinstance(request, CreateNodeRequest):
|
|
132
|
+
return self._create_authenticated_node(
|
|
133
|
+
client_public_key_bytes, request, context
|
|
134
|
+
)
|
|
135
|
+
|
|
136
|
+
# Verify hmac value
|
|
137
|
+
hmac_value = base64.urlsafe_b64decode(
|
|
138
|
+
_get_value_from_tuples(
|
|
139
|
+
_AUTH_TOKEN_HEADER, context.invocation_metadata()
|
|
140
|
+
)
|
|
141
|
+
)
|
|
142
|
+
public_key = bytes_to_public_key(client_public_key_bytes)
|
|
143
|
+
|
|
144
|
+
if not self._verify_hmac(public_key, request, hmac_value):
|
|
145
|
+
context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
|
|
146
|
+
|
|
147
|
+
# Verify node_id
|
|
148
|
+
node_id = self.state.get_node_id(client_public_key_bytes)
|
|
149
|
+
|
|
150
|
+
if not self._verify_node_id(node_id, request):
|
|
151
|
+
context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
|
|
152
|
+
|
|
153
|
+
return method_handler.unary_unary(request, context) # type: ignore
|
|
154
|
+
|
|
155
|
+
return grpc.unary_unary_rpc_method_handler(
|
|
156
|
+
_generic_method_handler,
|
|
157
|
+
request_deserializer=method_handler.request_deserializer,
|
|
158
|
+
response_serializer=method_handler.response_serializer,
|
|
159
|
+
)
|
|
160
|
+
|
|
161
|
+
def _verify_node_id(
|
|
162
|
+
self,
|
|
163
|
+
node_id: Optional[int],
|
|
164
|
+
request: Union[
|
|
165
|
+
DeleteNodeRequest,
|
|
166
|
+
PullTaskInsRequest,
|
|
167
|
+
PushTaskResRequest,
|
|
168
|
+
GetRunRequest,
|
|
169
|
+
PingRequest,
|
|
170
|
+
],
|
|
171
|
+
) -> bool:
|
|
172
|
+
if node_id is None:
|
|
173
|
+
return False
|
|
174
|
+
if isinstance(request, PushTaskResRequest):
|
|
175
|
+
if len(request.task_res_list) == 0:
|
|
176
|
+
return False
|
|
177
|
+
return request.task_res_list[0].task.producer.node_id == node_id
|
|
178
|
+
if isinstance(request, GetRunRequest):
|
|
179
|
+
return node_id in self.state.get_nodes(request.run_id)
|
|
180
|
+
return request.node.node_id == node_id
|
|
181
|
+
|
|
182
|
+
def _verify_hmac(
|
|
183
|
+
self, public_key: ec.EllipticCurvePublicKey, request: Request, hmac_value: bytes
|
|
184
|
+
) -> bool:
|
|
185
|
+
shared_secret = generate_shared_key(self.server_private_key, public_key)
|
|
186
|
+
return verify_hmac(shared_secret, request.SerializeToString(True), hmac_value)
|
|
187
|
+
|
|
188
|
+
def _create_authenticated_node(
|
|
189
|
+
self,
|
|
190
|
+
public_key_bytes: bytes,
|
|
191
|
+
request: CreateNodeRequest,
|
|
192
|
+
context: grpc.ServicerContext,
|
|
193
|
+
) -> CreateNodeResponse:
|
|
194
|
+
context.send_initial_metadata(
|
|
195
|
+
(
|
|
196
|
+
(
|
|
197
|
+
_PUBLIC_KEY_HEADER,
|
|
198
|
+
self.encoded_server_public_key,
|
|
199
|
+
),
|
|
200
|
+
)
|
|
201
|
+
)
|
|
202
|
+
|
|
203
|
+
node_id = self.state.get_node_id(public_key_bytes)
|
|
204
|
+
|
|
205
|
+
# Handle `CreateNode` here instead of calling the default method handler
|
|
206
|
+
# Return previously assigned `node_id` for the provided `public_key`
|
|
207
|
+
if node_id is not None:
|
|
208
|
+
self.state.acknowledge_ping(node_id, request.ping_interval)
|
|
209
|
+
return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
|
|
210
|
+
|
|
211
|
+
# No `node_id` exists for the provided `public_key`
|
|
212
|
+
# Handle `CreateNode` here instead of calling the default method handler
|
|
213
|
+
# Note: the innermost `CreateNode` method will never be called
|
|
214
|
+
node_id = self.state.create_node(request.ping_interval, public_key_bytes)
|
|
215
|
+
return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
|
|
@@ -15,7 +15,7 @@
|
|
|
15
15
|
"""Ray backend for the Fleet API using the Simulation Engine."""
|
|
16
16
|
|
|
17
17
|
import pathlib
|
|
18
|
-
from logging import ERROR, INFO
|
|
18
|
+
from logging import DEBUG, ERROR, INFO, WARNING
|
|
19
19
|
from typing import Callable, Dict, List, Tuple, Union
|
|
20
20
|
|
|
21
21
|
import ray
|
|
@@ -46,7 +46,7 @@ class RayBackend(Backend):
|
|
|
46
46
|
) -> None:
|
|
47
47
|
"""Prepare RayBackend by initialising Ray and creating the ActorPool."""
|
|
48
48
|
log(INFO, "Initialising: %s", self.__class__.__name__)
|
|
49
|
-
log(
|
|
49
|
+
log(DEBUG, "Backend config: %s", backend_config)
|
|
50
50
|
|
|
51
51
|
if not pathlib.Path(work_dir).exists():
|
|
52
52
|
raise ValueError(f"Specified work_dir {work_dir} does not exist.")
|
|
@@ -55,7 +55,10 @@ class RayBackend(Backend):
|
|
|
55
55
|
runtime_env = (
|
|
56
56
|
self._configure_runtime_env(work_dir=work_dir) if work_dir else None
|
|
57
57
|
)
|
|
58
|
-
|
|
58
|
+
if backend_config.get("silent", False):
|
|
59
|
+
init_ray(logging_level=WARNING, log_to_driver=True, runtime_env=runtime_env)
|
|
60
|
+
else:
|
|
61
|
+
init_ray(runtime_env=runtime_env)
|
|
59
62
|
|
|
60
63
|
# Validate client resources
|
|
61
64
|
self.client_resources_key = "client_resources"
|
|
@@ -109,7 +112,7 @@ class RayBackend(Backend):
|
|
|
109
112
|
else:
|
|
110
113
|
client_resources = {"num_cpus": 2, "num_gpus": 0.0}
|
|
111
114
|
log(
|
|
112
|
-
|
|
115
|
+
DEBUG,
|
|
113
116
|
"`%s` not specified in backend config. Applying default setting: %s",
|
|
114
117
|
self.client_resources_key,
|
|
115
118
|
client_resources,
|
|
@@ -129,7 +132,7 @@ class RayBackend(Backend):
|
|
|
129
132
|
async def build(self) -> None:
|
|
130
133
|
"""Build pool of Ray actors that this backend will submit jobs to."""
|
|
131
134
|
await self.pool.add_actors_to_pool(self.pool.actors_capacity)
|
|
132
|
-
log(
|
|
135
|
+
log(DEBUG, "Constructed ActorPool with: %i actors", self.pool.num_actors)
|
|
133
136
|
|
|
134
137
|
async def process_message(
|
|
135
138
|
self,
|
|
@@ -173,4 +176,4 @@ class RayBackend(Backend):
|
|
|
173
176
|
"""Terminate all actors in actor pool."""
|
|
174
177
|
await self.pool.terminate_all_actors()
|
|
175
178
|
ray.shutdown()
|
|
176
|
-
log(
|
|
179
|
+
log(DEBUG, "Terminated %s", self.__class__.__name__)
|
|
@@ -293,7 +293,7 @@ def start_vce(
|
|
|
293
293
|
node_states[node_id] = NodeState()
|
|
294
294
|
|
|
295
295
|
# Load backend config
|
|
296
|
-
log(
|
|
296
|
+
log(DEBUG, "Supported backends: %s", list(supported_backends.keys()))
|
|
297
297
|
backend_config = json.loads(backend_config_json_stream)
|
|
298
298
|
|
|
299
299
|
try:
|
|
@@ -30,16 +30,24 @@ from flwr.server.utils import validate_task_ins_or_res
|
|
|
30
30
|
from .utils import make_node_unavailable_taskres
|
|
31
31
|
|
|
32
32
|
|
|
33
|
-
class InMemoryState(State):
|
|
33
|
+
class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
34
34
|
"""In-memory State implementation."""
|
|
35
35
|
|
|
36
36
|
def __init__(self) -> None:
|
|
37
|
+
|
|
37
38
|
# Map node_id to (online_until, ping_interval)
|
|
38
39
|
self.node_ids: Dict[int, Tuple[float, float]] = {}
|
|
40
|
+
self.public_key_to_node_id: Dict[bytes, int] = {}
|
|
41
|
+
|
|
39
42
|
# Map run_id to (fab_id, fab_version)
|
|
40
43
|
self.run_ids: Dict[int, Tuple[str, str]] = {}
|
|
41
44
|
self.task_ins_store: Dict[UUID, TaskIns] = {}
|
|
42
45
|
self.task_res_store: Dict[UUID, TaskRes] = {}
|
|
46
|
+
|
|
47
|
+
self.client_public_keys: Set[bytes] = set()
|
|
48
|
+
self.server_public_key: Optional[bytes] = None
|
|
49
|
+
self.server_private_key: Optional[bytes] = None
|
|
50
|
+
|
|
43
51
|
self.lock = threading.Lock()
|
|
44
52
|
|
|
45
53
|
def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]:
|
|
@@ -202,23 +210,46 @@ class InMemoryState(State):
|
|
|
202
210
|
"""
|
|
203
211
|
return len(self.task_res_store)
|
|
204
212
|
|
|
205
|
-
def create_node(
|
|
213
|
+
def create_node(
|
|
214
|
+
self, ping_interval: float, public_key: Optional[bytes] = None
|
|
215
|
+
) -> int:
|
|
206
216
|
"""Create, store in state, and return `node_id`."""
|
|
207
217
|
# Sample a random int64 as node_id
|
|
208
218
|
node_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
|
|
209
219
|
|
|
210
220
|
with self.lock:
|
|
211
|
-
if node_id
|
|
212
|
-
|
|
213
|
-
return
|
|
214
|
-
|
|
215
|
-
|
|
221
|
+
if node_id in self.node_ids:
|
|
222
|
+
log(ERROR, "Unexpected node registration failure.")
|
|
223
|
+
return 0
|
|
224
|
+
|
|
225
|
+
if public_key is not None:
|
|
226
|
+
if (
|
|
227
|
+
public_key in self.public_key_to_node_id
|
|
228
|
+
or node_id in self.public_key_to_node_id.values()
|
|
229
|
+
):
|
|
230
|
+
log(ERROR, "Unexpected node registration failure.")
|
|
231
|
+
return 0
|
|
232
|
+
|
|
233
|
+
self.public_key_to_node_id[public_key] = node_id
|
|
234
|
+
|
|
235
|
+
self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
|
|
236
|
+
return node_id
|
|
216
237
|
|
|
217
|
-
def delete_node(self, node_id: int) -> None:
|
|
238
|
+
def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None:
|
|
218
239
|
"""Delete a client node."""
|
|
219
240
|
with self.lock:
|
|
220
241
|
if node_id not in self.node_ids:
|
|
221
242
|
raise ValueError(f"Node {node_id} not found")
|
|
243
|
+
|
|
244
|
+
if public_key is not None:
|
|
245
|
+
if (
|
|
246
|
+
public_key not in self.public_key_to_node_id
|
|
247
|
+
or node_id not in self.public_key_to_node_id.values()
|
|
248
|
+
):
|
|
249
|
+
raise ValueError("Public key or node_id not found")
|
|
250
|
+
|
|
251
|
+
del self.public_key_to_node_id[public_key]
|
|
252
|
+
|
|
222
253
|
del self.node_ids[node_id]
|
|
223
254
|
|
|
224
255
|
def get_nodes(self, run_id: int) -> Set[int]:
|
|
@@ -239,6 +270,10 @@ class InMemoryState(State):
|
|
|
239
270
|
if online_until > current_time
|
|
240
271
|
}
|
|
241
272
|
|
|
273
|
+
def get_node_id(self, client_public_key: bytes) -> Optional[int]:
|
|
274
|
+
"""Retrieve stored `node_id` filtered by `client_public_keys`."""
|
|
275
|
+
return self.public_key_to_node_id.get(client_public_key)
|
|
276
|
+
|
|
242
277
|
def create_run(self, fab_id: str, fab_version: str) -> int:
|
|
243
278
|
"""Create a new run for the specified `fab_id` and `fab_version`."""
|
|
244
279
|
# Sample a random int64 as run_id
|
|
@@ -251,6 +286,39 @@ class InMemoryState(State):
|
|
|
251
286
|
log(ERROR, "Unexpected run creation failure.")
|
|
252
287
|
return 0
|
|
253
288
|
|
|
289
|
+
def store_server_private_public_key(
|
|
290
|
+
self, private_key: bytes, public_key: bytes
|
|
291
|
+
) -> None:
|
|
292
|
+
"""Store `server_private_key` and `server_public_key` in state."""
|
|
293
|
+
with self.lock:
|
|
294
|
+
if self.server_private_key is None and self.server_public_key is None:
|
|
295
|
+
self.server_private_key = private_key
|
|
296
|
+
self.server_public_key = public_key
|
|
297
|
+
else:
|
|
298
|
+
raise RuntimeError("Server private and public key already set")
|
|
299
|
+
|
|
300
|
+
def get_server_private_key(self) -> Optional[bytes]:
|
|
301
|
+
"""Retrieve `server_private_key` in urlsafe bytes."""
|
|
302
|
+
return self.server_private_key
|
|
303
|
+
|
|
304
|
+
def get_server_public_key(self) -> Optional[bytes]:
|
|
305
|
+
"""Retrieve `server_public_key` in urlsafe bytes."""
|
|
306
|
+
return self.server_public_key
|
|
307
|
+
|
|
308
|
+
def store_client_public_keys(self, public_keys: Set[bytes]) -> None:
|
|
309
|
+
"""Store a set of `client_public_keys` in state."""
|
|
310
|
+
with self.lock:
|
|
311
|
+
self.client_public_keys = public_keys
|
|
312
|
+
|
|
313
|
+
def store_client_public_key(self, public_key: bytes) -> None:
|
|
314
|
+
"""Store a `client_public_key` in state."""
|
|
315
|
+
with self.lock:
|
|
316
|
+
self.client_public_keys.add(public_key)
|
|
317
|
+
|
|
318
|
+
def get_client_public_keys(self) -> Set[bytes]:
|
|
319
|
+
"""Retrieve all currently stored `client_public_keys` as a set."""
|
|
320
|
+
return self.client_public_keys
|
|
321
|
+
|
|
254
322
|
def get_run(self, run_id: int) -> Tuple[int, str, str]:
|
|
255
323
|
"""Retrieve information about the run with the specified `run_id`."""
|
|
256
324
|
with self.lock:
|