flwr-nightly 1.14.0.dev20241216__py3-none-any.whl → 1.15.0.dev20250112__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/cli_user_auth_interceptor.py +6 -2
- flwr/cli/log.py +8 -6
- flwr/cli/login/login.py +11 -4
- flwr/cli/ls.py +7 -4
- flwr/cli/new/templates/app/.gitignore.tpl +3 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/run.py +7 -2
- flwr/cli/stop.py +3 -2
- flwr/cli/utils.py +83 -14
- flwr/client/app.py +17 -9
- flwr/client/client.py +0 -32
- flwr/client/grpc_rere_client/client_interceptor.py +6 -0
- flwr/client/grpc_rere_client/grpc_adapter.py +16 -0
- flwr/client/message_handler/message_handler.py +0 -2
- flwr/client/numpy_client.py +0 -44
- flwr/client/supernode/app.py +1 -2
- flwr/common/auth_plugin/auth_plugin.py +33 -23
- flwr/common/constant.py +2 -0
- flwr/common/grpc.py +154 -3
- flwr/common/record/recordset.py +1 -1
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +45 -0
- flwr/common/telemetry.py +13 -3
- flwr/common/typing.py +20 -0
- flwr/proto/exec_pb2.py +12 -24
- flwr/proto/exec_pb2.pyi +27 -54
- flwr/proto/fleet_pb2.py +40 -27
- flwr/proto/fleet_pb2.pyi +84 -0
- flwr/proto/fleet_pb2_grpc.py +66 -0
- flwr/proto/fleet_pb2_grpc.pyi +20 -0
- flwr/server/app.py +54 -33
- flwr/server/run_serverapp.py +8 -9
- flwr/server/serverapp/app.py +17 -2
- flwr/server/superlink/driver/serverappio_grpc.py +1 -1
- flwr/server/superlink/driver/serverappio_servicer.py +29 -6
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +2 -165
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +16 -0
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -1
- flwr/server/superlink/fleet/vce/vce_api.py +2 -2
- flwr/server/superlink/linkstate/in_memory_linkstate.py +36 -24
- flwr/server/superlink/linkstate/linkstate.py +14 -4
- flwr/server/superlink/linkstate/sqlite_linkstate.py +56 -31
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
- flwr/server/superlink/simulation/simulationio_servicer.py +13 -0
- flwr/simulation/app.py +15 -4
- flwr/simulation/run_simulation.py +35 -7
- flwr/superexec/exec_grpc.py +1 -1
- flwr/superexec/exec_servicer.py +23 -2
- {flwr_nightly-1.14.0.dev20241216.dist-info → flwr_nightly-1.15.0.dev20250112.dist-info}/METADATA +5 -5
- {flwr_nightly-1.14.0.dev20241216.dist-info → flwr_nightly-1.15.0.dev20250112.dist-info}/RECORD +60 -60
- {flwr_nightly-1.14.0.dev20241216.dist-info → flwr_nightly-1.15.0.dev20250112.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.14.0.dev20241216.dist-info → flwr_nightly-1.15.0.dev20250112.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.14.0.dev20241216.dist-info → flwr_nightly-1.15.0.dev20250112.dist-info}/entry_points.txt +0 -0
flwr/server/run_serverapp.py
CHANGED
@@ -15,12 +15,12 @@
|
|
15
15
|
"""Run ServerApp."""
|
16
16
|
|
17
17
|
|
18
|
-
import sys
|
19
18
|
from logging import DEBUG, ERROR
|
20
19
|
from typing import Optional
|
21
20
|
|
22
|
-
from flwr.common import Context
|
23
|
-
from flwr.common.
|
21
|
+
from flwr.common import Context, EventType, event
|
22
|
+
from flwr.common.exit_handlers import register_exit_handlers
|
23
|
+
from flwr.common.logger import log
|
24
24
|
from flwr.common.object_ref import load_app
|
25
25
|
|
26
26
|
from .driver import Driver
|
@@ -66,12 +66,11 @@ def run(
|
|
66
66
|
return context
|
67
67
|
|
68
68
|
|
69
|
-
# pylint: disable-next=too-many-branches,too-many-statements,too-many-locals
|
70
69
|
def run_server_app() -> None:
|
71
70
|
"""Run Flower server app."""
|
72
|
-
|
73
|
-
|
74
|
-
|
71
|
+
event(EventType.RUN_SERVER_APP_ENTER)
|
72
|
+
log(
|
73
|
+
ERROR,
|
74
|
+
"The command `flower-server-app` has been replaced by `flwr run`.",
|
75
75
|
)
|
76
|
-
|
77
|
-
sys.exit()
|
76
|
+
register_exit_handlers(event_type=EventType.RUN_SERVER_APP_LEAVE)
|
flwr/server/serverapp/app.py
CHANGED
@@ -25,6 +25,7 @@ from typing import Optional
|
|
25
25
|
|
26
26
|
from flwr.cli.config_utils import get_fab_metadata
|
27
27
|
from flwr.cli.install import install_from_fab
|
28
|
+
from flwr.cli.utils import get_sha256_hash
|
28
29
|
from flwr.common.args import add_args_flwr_app_common
|
29
30
|
from flwr.common.config import (
|
30
31
|
get_flwr_dir,
|
@@ -51,6 +52,7 @@ from flwr.common.serde import (
|
|
51
52
|
run_from_proto,
|
52
53
|
run_status_to_proto,
|
53
54
|
)
|
55
|
+
from flwr.common.telemetry import EventType, event
|
54
56
|
from flwr.common.typing import RunNotRunningException, RunStatus
|
55
57
|
from flwr.proto.run_pb2 import UpdateRunStatusRequest # pylint: disable=E0611
|
56
58
|
from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
|
@@ -113,7 +115,8 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
113
115
|
# Resolve directory where FABs are installed
|
114
116
|
flwr_dir_ = get_flwr_dir(flwr_dir)
|
115
117
|
log_uploader = None
|
116
|
-
|
118
|
+
success = True
|
119
|
+
hash_run_id = None
|
117
120
|
while True:
|
118
121
|
|
119
122
|
try:
|
@@ -129,6 +132,8 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
129
132
|
run = run_from_proto(res.run)
|
130
133
|
fab = fab_from_proto(res.fab)
|
131
134
|
|
135
|
+
hash_run_id = get_sha256_hash(run.run_id)
|
136
|
+
|
132
137
|
driver.set_run(run.run_id)
|
133
138
|
|
134
139
|
# Start log uploader for this run
|
@@ -171,6 +176,11 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
171
176
|
UpdateRunStatusRequest(run_id=run.run_id, run_status=run_status_proto)
|
172
177
|
)
|
173
178
|
|
179
|
+
event(
|
180
|
+
EventType.FLWR_SERVERAPP_RUN_ENTER,
|
181
|
+
event_details={"run-id-hash": hash_run_id},
|
182
|
+
)
|
183
|
+
|
174
184
|
# Load and run the ServerApp with the Driver
|
175
185
|
updated_context = run_(
|
176
186
|
driver=driver,
|
@@ -187,17 +197,18 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
187
197
|
_ = driver._stub.PushServerAppOutputs(out_req)
|
188
198
|
|
189
199
|
run_status = RunStatus(Status.FINISHED, SubStatus.COMPLETED, "")
|
190
|
-
|
191
200
|
except RunNotRunningException:
|
192
201
|
log(INFO, "")
|
193
202
|
log(INFO, "Run ID %s stopped.", run.run_id)
|
194
203
|
log(INFO, "")
|
195
204
|
run_status = None
|
205
|
+
success = False
|
196
206
|
|
197
207
|
except Exception as ex: # pylint: disable=broad-exception-caught
|
198
208
|
exc_entity = "ServerApp"
|
199
209
|
log(ERROR, "%s raised an exception", exc_entity, exc_info=ex)
|
200
210
|
run_status = RunStatus(Status.FINISHED, SubStatus.FAILED, str(ex))
|
211
|
+
success = False
|
201
212
|
|
202
213
|
finally:
|
203
214
|
# Stop log uploader for this run and upload final logs
|
@@ -213,6 +224,10 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
213
224
|
run_id=run.run_id, run_status=run_status_proto
|
214
225
|
)
|
215
226
|
)
|
227
|
+
event(
|
228
|
+
EventType.FLWR_SERVERAPP_RUN_LEAVE,
|
229
|
+
event_details={"run-id-hash": hash_run_id, "success": success},
|
230
|
+
)
|
216
231
|
|
217
232
|
# Stop the loop if `flwr-serverapp` is expected to process a single run
|
218
233
|
if run_once:
|
flwr/server/superlink/driver/serverappio_grpc.py
CHANGED
@@ -21,6 +21,7 @@ from typing import Optional
|
|
21
21
|
import grpc
|
22
22
|
|
23
23
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
24
|
+
from flwr.common.grpc import generic_create_grpc_server
|
24
25
|
from flwr.common.logger import log
|
25
26
|
from flwr.proto.serverappio_pb2_grpc import ( # pylint: disable=E0611
|
26
27
|
add_ServerAppIoServicer_to_server,
|
@@ -28,7 +29,6 @@ from flwr.proto.serverappio_pb2_grpc import ( # pylint: disable=E0611
|
|
28
29
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
29
30
|
from flwr.server.superlink.linkstate import LinkStateFactory
|
30
31
|
|
31
|
-
from ..fleet.grpc_bidi.grpc_server import generic_create_grpc_server
|
32
32
|
from .serverappio_servicer import ServerAppIoServicer
|
33
33
|
|
34
34
|
|
flwr/server/superlink/driver/serverappio_servicer.py
CHANGED
@@ -118,8 +118,9 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
118
118
|
ffs: Ffs = self.ffs_factory.ffs()
|
119
119
|
fab_hash = ffs.put(fab.content, {})
|
120
120
|
_raise_if(
|
121
|
-
fab_hash != fab.hash_str,
|
122
|
-
|
121
|
+
validation_error=fab_hash != fab.hash_str,
|
122
|
+
request_name="CreateRun",
|
123
|
+
detail=f"FAB ({fab.hash_str}) hash from request doesn't match contents",
|
123
124
|
)
|
124
125
|
else:
|
125
126
|
fab_hash = ""
|
@@ -155,10 +156,23 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
155
156
|
task_ins.task.pushed_at = pushed_at
|
156
157
|
|
157
158
|
# Validate request
|
158
|
-
_raise_if(
|
159
|
+
_raise_if(
|
160
|
+
validation_error=len(request.task_ins_list) == 0,
|
161
|
+
request_name="PushTaskIns",
|
162
|
+
detail="`task_ins_list` must not be empty",
|
163
|
+
)
|
159
164
|
for task_ins in request.task_ins_list:
|
160
165
|
validation_errors = validate_task_ins_or_res(task_ins)
|
161
|
-
_raise_if(
|
166
|
+
_raise_if(
|
167
|
+
validation_error=bool(validation_errors),
|
168
|
+
request_name="PushTaskIns",
|
169
|
+
detail=", ".join(validation_errors),
|
170
|
+
)
|
171
|
+
_raise_if(
|
172
|
+
validation_error=request.run_id != task_ins.run_id,
|
173
|
+
request_name="PushTaskIns",
|
174
|
+
detail="`task_ins` has mismatched `run_id`",
|
175
|
+
)
|
162
176
|
|
163
177
|
# Store each TaskIns
|
164
178
|
task_ids: list[Optional[UUID]] = []
|
@@ -193,6 +207,14 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
193
207
|
# Read from state
|
194
208
|
task_res_list: list[TaskRes] = state.get_task_res(task_ids=task_ids)
|
195
209
|
|
210
|
+
# Validate request
|
211
|
+
for task_res in task_res_list:
|
212
|
+
_raise_if(
|
213
|
+
validation_error=request.run_id != task_res.run_id,
|
214
|
+
request_name="PullTaskRes",
|
215
|
+
detail="`task_res` has mismatched `run_id`",
|
216
|
+
)
|
217
|
+
|
196
218
|
# Delete the TaskIns/TaskRes pairs if TaskRes is found
|
197
219
|
task_ins_ids_to_delete = {
|
198
220
|
UUID(task_res.task.ancestry[0]) for task_res in task_res_list
|
@@ -335,6 +357,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
335
357
|
return GetRunStatusResponse(run_status_dict=run_status_dict)
|
336
358
|
|
337
359
|
|
338
|
-
def _raise_if(validation_error: bool, detail: str) -> None:
|
360
|
+
def _raise_if(validation_error: bool, request_name: str, detail: str) -> None:
|
361
|
+
"""Raise a `ValueError` with a detailed message if a validation error occurs."""
|
339
362
|
if validation_error:
|
340
|
-
raise ValueError(f"Malformed
|
363
|
+
raise ValueError(f"Malformed {request_name}: {detail}")
|
flwr/server/superlink/fleet/grpc_bidi/grpc_server.py
CHANGED
@@ -15,49 +15,19 @@
|
|
15
15
|
"""Implements utility function to create a gRPC server."""
|
16
16
|
|
17
17
|
|
18
|
-
import
|
19
|
-
import sys
|
20
|
-
from collections.abc import Sequence
|
21
|
-
from logging import ERROR
|
22
|
-
from typing import Any, Callable, Optional, Union
|
18
|
+
from typing import Optional
|
23
19
|
|
24
20
|
import grpc
|
25
21
|
|
26
22
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
27
|
-
from flwr.common.
|
28
|
-
from flwr.common.logger import log
|
23
|
+
from flwr.common.grpc import generic_create_grpc_server
|
29
24
|
from flwr.proto.transport_pb2_grpc import ( # pylint: disable=E0611
|
30
25
|
add_FlowerServiceServicer_to_server,
|
31
26
|
)
|
32
27
|
from flwr.server.client_manager import ClientManager
|
33
|
-
from flwr.server.superlink.driver.serverappio_servicer import ServerAppIoServicer
|
34
|
-
from flwr.server.superlink.fleet.grpc_adapter.grpc_adapter_servicer import (
|
35
|
-
GrpcAdapterServicer,
|
36
|
-
)
|
37
28
|
from flwr.server.superlink.fleet.grpc_bidi.flower_service_servicer import (
|
38
29
|
FlowerServiceServicer,
|
39
30
|
)
|
40
|
-
from flwr.server.superlink.fleet.grpc_rere.fleet_servicer import FleetServicer
|
41
|
-
|
42
|
-
INVALID_CERTIFICATES_ERR_MSG = """
|
43
|
-
When setting any of root_certificate, certificate, or private_key,
|
44
|
-
all of them need to be set.
|
45
|
-
"""
|
46
|
-
|
47
|
-
AddServicerToServerFn = Callable[..., Any]
|
48
|
-
|
49
|
-
|
50
|
-
def valid_certificates(certificates: tuple[bytes, bytes, bytes]) -> bool:
|
51
|
-
"""Validate certificates tuple."""
|
52
|
-
is_valid = (
|
53
|
-
all(isinstance(certificate, bytes) for certificate in certificates)
|
54
|
-
and len(certificates) == 3
|
55
|
-
)
|
56
|
-
|
57
|
-
if not is_valid:
|
58
|
-
log(ERROR, INVALID_CERTIFICATES_ERR_MSG)
|
59
|
-
|
60
|
-
return is_valid
|
61
31
|
|
62
32
|
|
63
33
|
def start_grpc_server( # pylint: disable=too-many-arguments,R0917
|
@@ -154,136 +124,3 @@ def start_grpc_server( # pylint: disable=too-many-arguments,R0917
|
|
154
124
|
server.start()
|
155
125
|
|
156
126
|
return server
|
157
|
-
|
158
|
-
|
159
|
-
def generic_create_grpc_server( # pylint: disable=too-many-arguments,R0917
|
160
|
-
servicer_and_add_fn: Union[
|
161
|
-
tuple[FleetServicer, AddServicerToServerFn],
|
162
|
-
tuple[GrpcAdapterServicer, AddServicerToServerFn],
|
163
|
-
tuple[FlowerServiceServicer, AddServicerToServerFn],
|
164
|
-
tuple[ServerAppIoServicer, AddServicerToServerFn],
|
165
|
-
],
|
166
|
-
server_address: str,
|
167
|
-
max_concurrent_workers: int = 1000,
|
168
|
-
max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
|
169
|
-
keepalive_time_ms: int = 210000,
|
170
|
-
certificates: Optional[tuple[bytes, bytes, bytes]] = None,
|
171
|
-
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None,
|
172
|
-
) -> grpc.Server:
|
173
|
-
"""Create a gRPC server with a single servicer.
|
174
|
-
|
175
|
-
Parameters
|
176
|
-
----------
|
177
|
-
servicer_and_add_fn : tuple
|
178
|
-
A tuple holding a servicer implementation and a matching
|
179
|
-
add_Servicer_to_server function.
|
180
|
-
server_address : str
|
181
|
-
Server address in the form of HOST:PORT e.g. "[::]:8080"
|
182
|
-
max_concurrent_workers : int
|
183
|
-
Maximum number of clients the server can process before returning
|
184
|
-
RESOURCE_EXHAUSTED status (default: 1000)
|
185
|
-
max_message_length : int
|
186
|
-
Maximum message length that the server can send or receive.
|
187
|
-
Int valued in bytes. -1 means unlimited. (default: GRPC_MAX_MESSAGE_LENGTH)
|
188
|
-
keepalive_time_ms : int
|
189
|
-
Flower uses a default gRPC keepalive time of 210000ms (3 minutes 30 seconds)
|
190
|
-
because some cloud providers (for example, Azure) agressively clean up idle
|
191
|
-
TCP connections by terminating them after some time (4 minutes in the case
|
192
|
-
of Azure). Flower does not use application-level keepalive signals and relies
|
193
|
-
on the assumption that the transport layer will fail in cases where the
|
194
|
-
connection is no longer active. `keepalive_time_ms` can be used to customize
|
195
|
-
the keepalive interval for specific environments. The default Flower gRPC
|
196
|
-
keepalive of 210000 ms (3 minutes 30 seconds) ensures that Flower can keep
|
197
|
-
the long running streaming connection alive in most environments. The actual
|
198
|
-
gRPC default of this setting is 7200000 (2 hours), which results in dropped
|
199
|
-
connections in some cloud environments.
|
200
|
-
|
201
|
-
These settings are related to the issue described here:
|
202
|
-
- https://github.com/grpc/proposal/blob/master/A8-client-side-keepalive.md
|
203
|
-
- https://github.com/grpc/grpc/blob/master/doc/keepalive.md
|
204
|
-
- https://grpc.io/docs/guides/performance/
|
205
|
-
|
206
|
-
Mobile Flower clients may choose to increase this value if their server
|
207
|
-
environment allows long-running idle TCP connections.
|
208
|
-
(default: 210000)
|
209
|
-
certificates : Tuple[bytes, bytes, bytes] (default: None)
|
210
|
-
Tuple containing root certificate, server certificate, and private key to
|
211
|
-
start a secure SSL-enabled server. The tuple is expected to have three bytes
|
212
|
-
elements in the following order:
|
213
|
-
|
214
|
-
* CA certificate.
|
215
|
-
* server certificate.
|
216
|
-
* server private key.
|
217
|
-
interceptors : Optional[Sequence[grpc.ServerInterceptor]] (default: None)
|
218
|
-
A list of gRPC interceptors.
|
219
|
-
|
220
|
-
Returns
|
221
|
-
-------
|
222
|
-
server : grpc.Server
|
223
|
-
A non-running instance of a gRPC server.
|
224
|
-
"""
|
225
|
-
# Check if port is in use
|
226
|
-
if is_port_in_use(server_address):
|
227
|
-
sys.exit(f"Port in server address {server_address} is already in use.")
|
228
|
-
|
229
|
-
# Deconstruct tuple into servicer and function
|
230
|
-
servicer, add_servicer_to_server_fn = servicer_and_add_fn
|
231
|
-
|
232
|
-
# Possible options:
|
233
|
-
# https://github.com/grpc/grpc/blob/v1.43.x/include/grpc/impl/codegen/grpc_types.h
|
234
|
-
options = [
|
235
|
-
# Maximum number of concurrent incoming streams to allow on a http2
|
236
|
-
# connection. Int valued.
|
237
|
-
("grpc.max_concurrent_streams", max(100, max_concurrent_workers)),
|
238
|
-
# Maximum message length that the channel can send.
|
239
|
-
# Int valued, bytes. -1 means unlimited.
|
240
|
-
("grpc.max_send_message_length", max_message_length),
|
241
|
-
# Maximum message length that the channel can receive.
|
242
|
-
# Int valued, bytes. -1 means unlimited.
|
243
|
-
("grpc.max_receive_message_length", max_message_length),
|
244
|
-
# The gRPC default for this setting is 7200000 (2 hours). Flower uses a
|
245
|
-
# customized default of 210000 (3 minutes and 30 seconds) to improve
|
246
|
-
# compatibility with popular cloud providers. Mobile Flower clients may
|
247
|
-
# choose to increase this value if their server environment allows
|
248
|
-
# long-running idle TCP connections.
|
249
|
-
("grpc.keepalive_time_ms", keepalive_time_ms),
|
250
|
-
# Setting this to zero will allow sending unlimited keepalive pings in between
|
251
|
-
# sending actual data frames.
|
252
|
-
("grpc.http2.max_pings_without_data", 0),
|
253
|
-
# Is it permissible to send keepalive pings from the client without
|
254
|
-
# any outstanding streams. More explanation here:
|
255
|
-
# https://github.com/adap/flower/pull/2197
|
256
|
-
("grpc.keepalive_permit_without_calls", 0),
|
257
|
-
]
|
258
|
-
|
259
|
-
server = grpc.server(
|
260
|
-
concurrent.futures.ThreadPoolExecutor(max_workers=max_concurrent_workers),
|
261
|
-
# Set the maximum number of concurrent RPCs this server will service before
|
262
|
-
# returning RESOURCE_EXHAUSTED status, or None to indicate no limit.
|
263
|
-
maximum_concurrent_rpcs=max_concurrent_workers,
|
264
|
-
options=options,
|
265
|
-
interceptors=interceptors,
|
266
|
-
)
|
267
|
-
add_servicer_to_server_fn(servicer, server)
|
268
|
-
|
269
|
-
if certificates is not None:
|
270
|
-
if not valid_certificates(certificates):
|
271
|
-
sys.exit(1)
|
272
|
-
|
273
|
-
root_certificate_b, certificate_b, private_key_b = certificates
|
274
|
-
|
275
|
-
server_credentials = grpc.ssl_server_credentials(
|
276
|
-
((private_key_b, certificate_b),),
|
277
|
-
root_certificates=root_certificate_b,
|
278
|
-
# A boolean indicating whether or not to require clients to be
|
279
|
-
# authenticated. May only be True if root_certificates is not None.
|
280
|
-
# We are explicitly setting the current gRPC default to document
|
281
|
-
# the option. For further reference see:
|
282
|
-
# https://grpc.github.io/grpc/python/grpc.html#create-server-credentials
|
283
|
-
require_client_auth=False,
|
284
|
-
)
|
285
|
-
server.add_secure_port(server_address, server_credentials)
|
286
|
-
else:
|
287
|
-
server.add_insecure_port(server_address)
|
288
|
-
|
289
|
-
return server
|
flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py
CHANGED
@@ -30,8 +30,12 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
30
30
|
DeleteNodeResponse,
|
31
31
|
PingRequest,
|
32
32
|
PingResponse,
|
33
|
+
PullMessagesRequest,
|
34
|
+
PullMessagesResponse,
|
33
35
|
PullTaskInsRequest,
|
34
36
|
PullTaskInsResponse,
|
37
|
+
PushMessagesRequest,
|
38
|
+
PushMessagesResponse,
|
35
39
|
PushTaskResRequest,
|
36
40
|
PushTaskResResponse,
|
37
41
|
)
|
@@ -95,6 +99,12 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
95
99
|
state=self.state_factory.state(),
|
96
100
|
)
|
97
101
|
|
102
|
+
def PullMessages(
|
103
|
+
self, request: PullMessagesRequest, context: grpc.ServicerContext
|
104
|
+
) -> PullMessagesResponse:
|
105
|
+
"""Pull Messages."""
|
106
|
+
return PullMessagesResponse()
|
107
|
+
|
98
108
|
def PushTaskRes(
|
99
109
|
self, request: PushTaskResRequest, context: grpc.ServicerContext
|
100
110
|
) -> PushTaskResResponse:
|
@@ -118,6 +128,12 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
118
128
|
|
119
129
|
return res
|
120
130
|
|
131
|
+
def PushMessages(
|
132
|
+
self, request: PushMessagesRequest, context: grpc.ServicerContext
|
133
|
+
) -> PushMessagesResponse:
|
134
|
+
"""Push Messages."""
|
135
|
+
return PushMessagesResponse()
|
136
|
+
|
121
137
|
def GetRun(
|
122
138
|
self, request: GetRunRequest, context: grpc.ServicerContext
|
123
139
|
) -> GetRunResponse:
|
flwr/server/superlink/fleet/grpc_rere/server_interceptor.py
CHANGED
@@ -223,5 +223,6 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
223
223
|
# No `node_id` exists for the provided `public_key`
|
224
224
|
# Handle `CreateNode` here instead of calling the default method handler
|
225
225
|
# Note: the innermost `CreateNode` method will never be called
|
226
|
-
node_id = state.create_node(request.ping_interval
|
226
|
+
node_id = state.create_node(request.ping_interval)
|
227
|
+
state.set_node_public_key(node_id, public_key_bytes)
|
227
228
|
return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
|
flwr/server/superlink/fleet/vce/vce_api.py
CHANGED
@@ -182,8 +182,8 @@ def run_api(
|
|
182
182
|
f_stop: threading.Event,
|
183
183
|
) -> None:
|
184
184
|
"""Run the VCE."""
|
185
|
-
taskins_queue:
|
186
|
-
taskres_queue:
|
185
|
+
taskins_queue: Queue[TaskIns] = Queue()
|
186
|
+
taskres_queue: Queue[TaskRes] = Queue()
|
187
187
|
|
188
188
|
try:
|
189
189
|
|
flwr/server/superlink/linkstate/in_memory_linkstate.py
CHANGED
@@ -62,6 +62,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
62
62
|
# Map node_id to (online_until, ping_interval)
|
63
63
|
self.node_ids: dict[int, tuple[float, float]] = {}
|
64
64
|
self.public_key_to_node_id: dict[bytes, int] = {}
|
65
|
+
self.node_id_to_public_key: dict[int, bytes] = {}
|
65
66
|
|
66
67
|
# Map run_id to RunRecord
|
67
68
|
self.run_ids: dict[int, RunRecord] = {}
|
@@ -306,9 +307,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
306
307
|
"""
|
307
308
|
return len(self.task_res_store)
|
308
309
|
|
309
|
-
def create_node(
|
310
|
-
self, ping_interval: float, public_key: Optional[bytes] = None
|
311
|
-
) -> int:
|
310
|
+
def create_node(self, ping_interval: float) -> int:
|
312
311
|
"""Create, store in the link state, and return `node_id`."""
|
313
312
|
# Sample a random int64 as node_id
|
314
313
|
node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)
|
@@ -318,33 +317,18 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
318
317
|
log(ERROR, "Unexpected node registration failure.")
|
319
318
|
return 0
|
320
319
|
|
321
|
-
if public_key is not None:
|
322
|
-
if (
|
323
|
-
public_key in self.public_key_to_node_id
|
324
|
-
or node_id in self.public_key_to_node_id.values()
|
325
|
-
):
|
326
|
-
log(ERROR, "Unexpected node registration failure.")
|
327
|
-
return 0
|
328
|
-
|
329
|
-
self.public_key_to_node_id[public_key] = node_id
|
330
|
-
|
331
320
|
self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
|
332
321
|
return node_id
|
333
322
|
|
334
|
-
def delete_node(self, node_id: int
|
323
|
+
def delete_node(self, node_id: int) -> None:
|
335
324
|
"""Delete a node."""
|
336
325
|
with self.lock:
|
337
326
|
if node_id not in self.node_ids:
|
338
327
|
raise ValueError(f"Node {node_id} not found")
|
339
328
|
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
or node_id not in self.public_key_to_node_id.values()
|
344
|
-
):
|
345
|
-
raise ValueError("Public key or node_id not found")
|
346
|
-
|
347
|
-
del self.public_key_to_node_id[public_key]
|
329
|
+
# Remove node ID <> public key mappings
|
330
|
+
if pk := self.node_id_to_public_key.pop(node_id, None):
|
331
|
+
del self.public_key_to_node_id[pk]
|
348
332
|
|
349
333
|
del self.node_ids[node_id]
|
350
334
|
|
@@ -366,6 +350,26 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
366
350
|
if online_until > current_time
|
367
351
|
}
|
368
352
|
|
353
|
+
def set_node_public_key(self, node_id: int, public_key: bytes) -> None:
|
354
|
+
"""Set `public_key` for the specified `node_id`."""
|
355
|
+
with self.lock:
|
356
|
+
if node_id not in self.node_ids:
|
357
|
+
raise ValueError(f"Node {node_id} not found")
|
358
|
+
|
359
|
+
if public_key in self.public_key_to_node_id:
|
360
|
+
raise ValueError("Public key already in use")
|
361
|
+
|
362
|
+
self.public_key_to_node_id[public_key] = node_id
|
363
|
+
self.node_id_to_public_key[node_id] = public_key
|
364
|
+
|
365
|
+
def get_node_public_key(self, node_id: int) -> Optional[bytes]:
|
366
|
+
"""Get `public_key` for the specified `node_id`."""
|
367
|
+
with self.lock:
|
368
|
+
if node_id not in self.node_ids:
|
369
|
+
raise ValueError(f"Node {node_id} not found")
|
370
|
+
|
371
|
+
return self.node_id_to_public_key.get(node_id)
|
372
|
+
|
369
373
|
def get_node_id(self, node_public_key: bytes) -> Optional[int]:
|
370
374
|
"""Retrieve stored `node_id` filtered by `node_public_keys`."""
|
371
375
|
return self.public_key_to_node_id.get(node_public_key)
|
@@ -430,10 +434,17 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
430
434
|
"""Retrieve `server_public_key` in urlsafe bytes."""
|
431
435
|
return self.server_public_key
|
432
436
|
|
437
|
+
def clear_supernode_auth_keys_and_credentials(self) -> None:
|
438
|
+
"""Clear stored `node_public_keys` and credentials in the link state if any."""
|
439
|
+
with self.lock:
|
440
|
+
self.server_private_key = None
|
441
|
+
self.server_public_key = None
|
442
|
+
self.node_public_keys.clear()
|
443
|
+
|
433
444
|
def store_node_public_keys(self, public_keys: set[bytes]) -> None:
|
434
445
|
"""Store a set of `node_public_keys` in the link state."""
|
435
446
|
with self.lock:
|
436
|
-
self.node_public_keys
|
447
|
+
self.node_public_keys.update(public_keys)
|
437
448
|
|
438
449
|
def store_node_public_key(self, public_key: bytes) -> None:
|
439
450
|
"""Store a `node_public_key` in the link state."""
|
@@ -442,7 +453,8 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
|
|
442
453
|
|
443
454
|
def get_node_public_keys(self) -> set[bytes]:
|
444
455
|
"""Retrieve all currently stored `node_public_keys` as a set."""
|
445
|
-
|
456
|
+
with self.lock:
|
457
|
+
return self.node_public_keys.copy()
|
446
458
|
|
447
459
|
def get_run_ids(self) -> set[int]:
|
448
460
|
"""Retrieve all run IDs."""
|
flwr/server/superlink/linkstate/linkstate.py
CHANGED
@@ -154,13 +154,11 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
154
154
|
"""Get all TaskIns IDs for the given run_id."""
|
155
155
|
|
156
156
|
@abc.abstractmethod
|
157
|
-
def create_node(
|
158
|
-
self, ping_interval: float, public_key: Optional[bytes] = None
|
159
|
-
) -> int:
|
157
|
+
def create_node(self, ping_interval: float) -> int:
|
160
158
|
"""Create, store in the link state, and return `node_id`."""
|
161
159
|
|
162
160
|
@abc.abstractmethod
|
163
|
-
def delete_node(self, node_id: int
|
161
|
+
def delete_node(self, node_id: int) -> None:
|
164
162
|
"""Remove `node_id` from the link state."""
|
165
163
|
|
166
164
|
@abc.abstractmethod
|
@@ -173,6 +171,14 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
173
171
|
an empty `Set` MUST be returned.
|
174
172
|
"""
|
175
173
|
|
174
|
+
@abc.abstractmethod
|
175
|
+
def set_node_public_key(self, node_id: int, public_key: bytes) -> None:
|
176
|
+
"""Set `public_key` for the specified `node_id`."""
|
177
|
+
|
178
|
+
@abc.abstractmethod
|
179
|
+
def get_node_public_key(self, node_id: int) -> Optional[bytes]:
|
180
|
+
"""Get `public_key` for the specified `node_id`."""
|
181
|
+
|
176
182
|
@abc.abstractmethod
|
177
183
|
def get_node_id(self, node_public_key: bytes) -> Optional[int]:
|
178
184
|
"""Retrieve stored `node_id` filtered by `node_public_keys`."""
|
@@ -284,6 +290,10 @@ class LinkState(abc.ABC): # pylint: disable=R0904
|
|
284
290
|
def get_server_public_key(self) -> Optional[bytes]:
|
285
291
|
"""Retrieve `server_public_key` in urlsafe bytes."""
|
286
292
|
|
293
|
+
@abc.abstractmethod
|
294
|
+
def clear_supernode_auth_keys_and_credentials(self) -> None:
|
295
|
+
"""Clear stored `node_public_keys` and credentials in the link state if any."""
|
296
|
+
|
287
297
|
@abc.abstractmethod
|
288
298
|
def store_node_public_keys(self, public_keys: set[bytes]) -> None:
|
289
299
|
"""Store a set of `node_public_keys` in the link state."""
|