flwr 1.22.0__py3-none-any.whl → 1.23.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/app.py +15 -1
- flwr/cli/auth_plugin/__init__.py +15 -6
- flwr/cli/auth_plugin/auth_plugin.py +95 -0
- flwr/cli/auth_plugin/noop_auth_plugin.py +58 -0
- flwr/cli/auth_plugin/oidc_cli_plugin.py +16 -25
- flwr/cli/build.py +118 -47
- flwr/cli/{cli_user_auth_interceptor.py → cli_account_auth_interceptor.py} +6 -5
- flwr/cli/log.py +2 -2
- flwr/cli/login/login.py +34 -23
- flwr/cli/ls.py +13 -9
- flwr/cli/new/new.py +187 -35
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +1 -1
- flwr/cli/pull.py +2 -2
- flwr/cli/run/run.py +11 -7
- flwr/cli/stop.py +2 -2
- flwr/cli/supernode/__init__.py +25 -0
- flwr/cli/supernode/ls.py +260 -0
- flwr/cli/supernode/register.py +185 -0
- flwr/cli/supernode/unregister.py +138 -0
- flwr/cli/utils.py +92 -69
- flwr/client/__init__.py +2 -1
- flwr/client/grpc_adapter_client/connection.py +6 -8
- flwr/client/grpc_rere_client/connection.py +59 -31
- flwr/client/grpc_rere_client/grpc_adapter.py +28 -12
- flwr/client/grpc_rere_client/{client_interceptor.py → node_auth_client_interceptor.py} +3 -6
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +7 -5
- flwr/client/rest_client/connection.py +82 -37
- flwr/clientapp/__init__.py +1 -2
- flwr/{client/clientapp → clientapp}/utils.py +1 -1
- flwr/common/constant.py +53 -13
- flwr/common/exit/exit_code.py +20 -10
- flwr/common/inflatable_utils.py +10 -10
- flwr/common/record/array.py +3 -3
- flwr/common/record/arrayrecord.py +10 -1
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -89
- flwr/common/serde.py +4 -2
- flwr/common/typing.py +7 -6
- flwr/compat/client/app.py +1 -1
- flwr/compat/client/grpc_client/connection.py +2 -2
- flwr/proto/control_pb2.py +48 -35
- flwr/proto/control_pb2.pyi +71 -5
- flwr/proto/control_pb2_grpc.py +102 -0
- flwr/proto/control_pb2_grpc.pyi +39 -0
- flwr/proto/fab_pb2.py +11 -7
- flwr/proto/fab_pb2.pyi +21 -1
- flwr/proto/fleet_pb2.py +31 -23
- flwr/proto/fleet_pb2.pyi +63 -23
- flwr/proto/fleet_pb2_grpc.py +98 -28
- flwr/proto/fleet_pb2_grpc.pyi +45 -13
- flwr/proto/node_pb2.py +3 -1
- flwr/proto/node_pb2.pyi +48 -0
- flwr/server/app.py +139 -114
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +17 -7
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +132 -38
- flwr/server/superlink/fleet/grpc_rere/{server_interceptor.py → node_auth_server_interceptor.py} +27 -51
- flwr/server/superlink/fleet/message_handler/message_handler.py +67 -22
- flwr/server/superlink/fleet/rest_rere/rest_api.py +52 -31
- flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -1
- flwr/server/superlink/fleet/vce/vce_api.py +18 -5
- flwr/server/superlink/linkstate/in_memory_linkstate.py +167 -73
- flwr/server/superlink/linkstate/linkstate.py +107 -24
- flwr/server/superlink/linkstate/linkstate_factory.py +2 -1
- flwr/server/superlink/linkstate/sqlite_linkstate.py +306 -255
- flwr/server/superlink/linkstate/utils.py +3 -54
- flwr/server/superlink/serverappio/serverappio_servicer.py +2 -2
- flwr/server/superlink/simulation/simulationio_servicer.py +1 -1
- flwr/server/utils/validator.py +2 -3
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +4 -2
- flwr/simulation/ray_transport/ray_actor.py +1 -1
- flwr/simulation/ray_transport/ray_client_proxy.py +1 -1
- flwr/simulation/run_simulation.py +3 -2
- flwr/supercore/constant.py +22 -0
- flwr/supercore/object_store/in_memory_object_store.py +0 -4
- flwr/supercore/object_store/object_store_factory.py +26 -6
- flwr/supercore/object_store/sqlite_object_store.py +252 -0
- flwr/{client/clientapp → supercore/primitives}/__init__.py +1 -1
- flwr/supercore/primitives/asymmetric.py +117 -0
- flwr/supercore/primitives/asymmetric_ed25519.py +165 -0
- flwr/supercore/sqlite_mixin.py +156 -0
- flwr/supercore/utils.py +20 -0
- flwr/{common → superlink}/auth_plugin/__init__.py +6 -6
- flwr/superlink/auth_plugin/auth_plugin.py +91 -0
- flwr/superlink/auth_plugin/noop_auth_plugin.py +87 -0
- flwr/superlink/servicer/control/{control_user_auth_interceptor.py → control_account_auth_interceptor.py} +19 -19
- flwr/superlink/servicer/control/control_event_log_interceptor.py +1 -1
- flwr/superlink/servicer/control/control_grpc.py +13 -11
- flwr/superlink/servicer/control/control_servicer.py +152 -60
- flwr/supernode/cli/flower_supernode.py +19 -26
- flwr/supernode/runtime/run_clientapp.py +2 -2
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +1 -1
- flwr/supernode/start_client_internal.py +17 -9
- {flwr-1.22.0.dist-info → flwr-1.23.0.dist-info}/METADATA +1 -1
- {flwr-1.22.0.dist-info → flwr-1.23.0.dist-info}/RECORD +107 -96
- flwr/common/auth_plugin/auth_plugin.py +0 -149
- /flwr/{client → clientapp}/client_app.py +0 -0
- {flwr-1.22.0.dist-info → flwr-1.23.0.dist-info}/WHEEL +0 -0
- {flwr-1.22.0.dist-info → flwr-1.23.0.dist-info}/entry_points.txt +0 -0
flwr/server/app.py
CHANGED
|
@@ -16,30 +16,26 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import argparse
|
|
19
|
-
import csv
|
|
20
19
|
import importlib.util
|
|
21
20
|
import os
|
|
22
21
|
import subprocess
|
|
23
22
|
import sys
|
|
24
23
|
import threading
|
|
25
24
|
from collections.abc import Sequence
|
|
26
|
-
from logging import
|
|
25
|
+
from logging import INFO, WARN
|
|
27
26
|
from pathlib import Path
|
|
28
27
|
from time import sleep
|
|
29
|
-
from typing import
|
|
28
|
+
from typing import Callable, Optional, TypeVar, cast
|
|
30
29
|
|
|
31
30
|
import grpc
|
|
32
31
|
import yaml
|
|
33
|
-
from cryptography.hazmat.primitives.asymmetric import ec
|
|
34
|
-
from cryptography.hazmat.primitives.serialization import load_ssh_public_key
|
|
35
32
|
|
|
36
33
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
|
|
37
34
|
from flwr.common.address import parse_address
|
|
38
35
|
from flwr.common.args import try_obtain_server_certificates
|
|
39
|
-
from flwr.common.auth_plugin import ControlAuthPlugin, ControlAuthzPlugin
|
|
40
36
|
from flwr.common.config import get_flwr_dir
|
|
41
37
|
from flwr.common.constant import (
|
|
42
|
-
|
|
38
|
+
AUTHN_TYPE_YAML_KEY,
|
|
43
39
|
AUTHZ_TYPE_YAML_KEY,
|
|
44
40
|
CLIENT_OCTET,
|
|
45
41
|
CONTROL_API_DEFAULT_SERVER_ADDRESS,
|
|
@@ -53,6 +49,8 @@ from flwr.common.constant import (
|
|
|
53
49
|
TRANSPORT_TYPE_GRPC_ADAPTER,
|
|
54
50
|
TRANSPORT_TYPE_GRPC_RERE,
|
|
55
51
|
TRANSPORT_TYPE_REST,
|
|
52
|
+
AuthnType,
|
|
53
|
+
AuthzType,
|
|
56
54
|
EventLogWriterType,
|
|
57
55
|
ExecPluginType,
|
|
58
56
|
)
|
|
@@ -60,37 +58,43 @@ from flwr.common.event_log_plugin import EventLogWriterPlugin
|
|
|
60
58
|
from flwr.common.exit import ExitCode, flwr_exit, register_signal_handlers
|
|
61
59
|
from flwr.common.grpc import generic_create_grpc_server
|
|
62
60
|
from flwr.common.logger import log
|
|
63
|
-
from flwr.common.
|
|
64
|
-
public_key_to_bytes,
|
|
65
|
-
)
|
|
61
|
+
from flwr.common.version import package_version
|
|
66
62
|
from flwr.proto.fleet_pb2_grpc import ( # pylint: disable=E0611
|
|
67
63
|
add_FleetServicer_to_server,
|
|
68
64
|
)
|
|
69
65
|
from flwr.proto.grpcadapter_pb2_grpc import add_GrpcAdapterServicer_to_server
|
|
70
66
|
from flwr.server.fleet_event_log_interceptor import FleetEventLogInterceptor
|
|
67
|
+
from flwr.supercore.constant import FLWR_IN_MEMORY_DB_NAME
|
|
71
68
|
from flwr.supercore.ffs import FfsFactory
|
|
72
69
|
from flwr.supercore.grpc_health import add_args_health, run_health_server_grpc_no_tls
|
|
73
70
|
from flwr.supercore.object_store import ObjectStoreFactory
|
|
74
71
|
from flwr.superlink.artifact_provider import ArtifactProvider
|
|
72
|
+
from flwr.superlink.auth_plugin import (
|
|
73
|
+
ControlAuthnPlugin,
|
|
74
|
+
ControlAuthzPlugin,
|
|
75
|
+
NoOpControlAuthnPlugin,
|
|
76
|
+
NoOpControlAuthzPlugin,
|
|
77
|
+
)
|
|
75
78
|
from flwr.superlink.servicer.control import run_control_api_grpc
|
|
76
79
|
|
|
77
80
|
from .superlink.fleet.grpc_adapter.grpc_adapter_servicer import GrpcAdapterServicer
|
|
78
81
|
from .superlink.fleet.grpc_rere.fleet_servicer import FleetServicer
|
|
79
|
-
from .superlink.fleet.grpc_rere.
|
|
82
|
+
from .superlink.fleet.grpc_rere.node_auth_server_interceptor import (
|
|
83
|
+
NodeAuthServerInterceptor,
|
|
84
|
+
)
|
|
80
85
|
from .superlink.linkstate import LinkStateFactory
|
|
81
86
|
from .superlink.serverappio.serverappio_grpc import run_serverappio_api_grpc
|
|
82
87
|
from .superlink.simulation.simulationio_grpc import run_simulationio_api_grpc
|
|
83
88
|
|
|
84
|
-
DATABASE = ":flwr-in-memory-state:"
|
|
85
89
|
BASE_DIR = get_flwr_dir() / "superlink" / "ffs"
|
|
86
|
-
P = TypeVar("P",
|
|
90
|
+
P = TypeVar("P", ControlAuthnPlugin, ControlAuthzPlugin)
|
|
87
91
|
|
|
88
92
|
|
|
89
93
|
try:
|
|
90
94
|
from flwr.ee import (
|
|
91
95
|
add_ee_args_superlink,
|
|
92
|
-
|
|
93
|
-
|
|
96
|
+
get_control_authn_ee_plugins,
|
|
97
|
+
get_control_authz_ee_plugins,
|
|
94
98
|
get_control_event_log_writer_plugins,
|
|
95
99
|
get_ee_artifact_provider,
|
|
96
100
|
get_fleet_event_log_writer_plugins,
|
|
@@ -101,14 +105,6 @@ except ImportError:
|
|
|
101
105
|
def add_ee_args_superlink(parser: argparse.ArgumentParser) -> None:
|
|
102
106
|
"""Add EE-specific arguments to the parser."""
|
|
103
107
|
|
|
104
|
-
def get_control_auth_plugins() -> dict[str, type[ControlAuthPlugin]]:
|
|
105
|
-
"""Return all Control API authentication plugins."""
|
|
106
|
-
raise NotImplementedError("No authentication plugins are currently supported.")
|
|
107
|
-
|
|
108
|
-
def get_control_authz_plugins() -> dict[str, type[ControlAuthzPlugin]]:
|
|
109
|
-
"""Return all Control API authorization plugins."""
|
|
110
|
-
raise NotImplementedError("No authorization plugins are currently supported.")
|
|
111
|
-
|
|
112
108
|
def get_control_event_log_writer_plugins() -> dict[str, type[EventLogWriterPlugin]]:
|
|
113
109
|
"""Return all Control API event log writer plugins."""
|
|
114
110
|
raise NotImplementedError(
|
|
@@ -125,6 +121,26 @@ except ImportError:
|
|
|
125
121
|
"No event log writer plugins are currently supported."
|
|
126
122
|
)
|
|
127
123
|
|
|
124
|
+
def get_control_authn_ee_plugins() -> dict[str, type[ControlAuthnPlugin]]:
|
|
125
|
+
"""Return all Control API authentication plugins for EE."""
|
|
126
|
+
return {}
|
|
127
|
+
|
|
128
|
+
def get_control_authz_ee_plugins() -> dict[str, type[ControlAuthzPlugin]]:
|
|
129
|
+
"""Return all Control API authorization plugins for EE."""
|
|
130
|
+
return {}
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def get_control_authn_plugins() -> dict[str, type[ControlAuthnPlugin]]:
|
|
134
|
+
"""Return all Control API authentication plugins."""
|
|
135
|
+
ee_dict: dict[str, type[ControlAuthnPlugin]] = get_control_authn_ee_plugins()
|
|
136
|
+
return ee_dict | {AuthnType.NOOP: NoOpControlAuthnPlugin}
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def get_control_authz_plugins() -> dict[str, type[ControlAuthzPlugin]]:
|
|
140
|
+
"""Return all Control API authorization plugins."""
|
|
141
|
+
ee_dict: dict[str, type[ControlAuthzPlugin]] = get_control_authz_ee_plugins()
|
|
142
|
+
return ee_dict | {AuthzType.NOOP: NoOpControlAuthzPlugin}
|
|
143
|
+
|
|
128
144
|
|
|
129
145
|
# pylint: disable=too-many-branches, too-many-locals, too-many-statements
|
|
130
146
|
def run_superlink() -> None:
|
|
@@ -189,18 +205,24 @@ def run_superlink() -> None:
|
|
|
189
205
|
# Obtain certificates
|
|
190
206
|
certificates = try_obtain_server_certificates(args)
|
|
191
207
|
|
|
192
|
-
# Disable the
|
|
208
|
+
# Disable the account auth TLS check if args.disable_oidc_tls_cert_verification is
|
|
193
209
|
# provided
|
|
194
210
|
verify_tls_cert = not getattr(args, "disable_oidc_tls_cert_verification", None)
|
|
195
211
|
|
|
196
|
-
|
|
212
|
+
authn_plugin: Optional[ControlAuthnPlugin] = None
|
|
197
213
|
authz_plugin: Optional[ControlAuthzPlugin] = None
|
|
198
214
|
event_log_plugin: Optional[EventLogWriterPlugin] = None
|
|
199
|
-
# Load the auth plugin if the args.
|
|
215
|
+
# Load the auth plugin if the args.account_auth_config is provided
|
|
200
216
|
if cfg_path := getattr(args, "user_auth_config", None):
|
|
201
|
-
|
|
202
|
-
|
|
217
|
+
log(
|
|
218
|
+
WARN,
|
|
219
|
+
"The `--user-auth-config` flag is deprecated and will be removed in a "
|
|
220
|
+
"future release. Please use `--account-auth-config` instead.",
|
|
203
221
|
)
|
|
222
|
+
args.account_auth_config = cfg_path
|
|
223
|
+
cfg_path = getattr(args, "account_auth_config", None)
|
|
224
|
+
authn_plugin, authz_plugin = _load_control_auth_plugins(cfg_path, verify_tls_cert)
|
|
225
|
+
if cfg_path is not None:
|
|
204
226
|
# Enable event logging if the args.enable_event_log is True
|
|
205
227
|
if args.enable_event_log:
|
|
206
228
|
event_log_plugin = _try_obtain_control_event_log_writer_plugin()
|
|
@@ -211,6 +233,54 @@ def run_superlink() -> None:
|
|
|
211
233
|
log(WARN, "The `--artifact-provider-config` flag is highly experimental.")
|
|
212
234
|
artifact_provider = get_ee_artifact_provider(cfg_path)
|
|
213
235
|
|
|
236
|
+
# Check for incompatible args with SuperNode authentication
|
|
237
|
+
enable_supernode_auth: bool = args.enable_supernode_auth
|
|
238
|
+
if enable_supernode_auth:
|
|
239
|
+
if args.insecure:
|
|
240
|
+
url_v = f"https://flower.ai/docs/framework/v{package_version}/en/"
|
|
241
|
+
page = "how-to-authenticate-supernodes.html"
|
|
242
|
+
flwr_exit(
|
|
243
|
+
ExitCode.SUPERLINK_INVALID_ARGS,
|
|
244
|
+
"The `--enable-supernode-auth` flag requires encrypted TLS "
|
|
245
|
+
"communications. Please provide TLS certificates using the "
|
|
246
|
+
"`--ssl-certfile`, `--ssl-keyfile` and `--ssl-ca-certfile` "
|
|
247
|
+
"arguments to your SuperLink. Please refer to the Flower "
|
|
248
|
+
f"documentation for more information: {url_v}{page}",
|
|
249
|
+
)
|
|
250
|
+
if args.fleet_api_type != TRANSPORT_TYPE_GRPC_RERE:
|
|
251
|
+
flwr_exit(
|
|
252
|
+
ExitCode.SUPERLINK_INVALID_ARGS,
|
|
253
|
+
"The `--enable-supernode-auth` flag is only supported "
|
|
254
|
+
"with the gRPC-rere Fleet API transport. Please set "
|
|
255
|
+
f"`--fleet-api-type` to `{TRANSPORT_TYPE_GRPC_RERE}`.",
|
|
256
|
+
)
|
|
257
|
+
if args.simulation:
|
|
258
|
+
log(
|
|
259
|
+
WARN,
|
|
260
|
+
"SuperNode authentication is not applicable with the simulation, "
|
|
261
|
+
"runtime as no SuperNodes can connect to this SuperLink. "
|
|
262
|
+
"Proceeding...",
|
|
263
|
+
)
|
|
264
|
+
# If supernode authentication is disabled, warn users
|
|
265
|
+
else:
|
|
266
|
+
log(
|
|
267
|
+
WARN,
|
|
268
|
+
"SuperNode authentication is disabled. The SuperLink will accept "
|
|
269
|
+
"connections from any SuperNode.",
|
|
270
|
+
)
|
|
271
|
+
|
|
272
|
+
if args.auth_list_public_keys:
|
|
273
|
+
url_v = f"https://flower.ai/docs/framework/v{package_version}/en/"
|
|
274
|
+
page = "how-to-authenticate-supernodes.html"
|
|
275
|
+
flwr_exit(
|
|
276
|
+
ExitCode.SUPERLINK_INVALID_ARGS,
|
|
277
|
+
"The `--auth-list-public-keys` "
|
|
278
|
+
"argument is no longer supported. To enable SuperNode authentication, "
|
|
279
|
+
"use the `--enable-supernode-auth` flag and use the Flower CLI to register "
|
|
280
|
+
"SuperNodes by supplying their public keys. Please refer"
|
|
281
|
+
f" to the Flower documentation for more information: {url_v}{page}",
|
|
282
|
+
)
|
|
283
|
+
|
|
214
284
|
# Initialize StateFactory
|
|
215
285
|
state_factory = LinkStateFactory(args.database)
|
|
216
286
|
|
|
@@ -218,7 +288,7 @@ def run_superlink() -> None:
|
|
|
218
288
|
ffs_factory = FfsFactory(args.storage_dir)
|
|
219
289
|
|
|
220
290
|
# Initialize ObjectStoreFactory
|
|
221
|
-
objectstore_factory = ObjectStoreFactory()
|
|
291
|
+
objectstore_factory = ObjectStoreFactory(args.database)
|
|
222
292
|
|
|
223
293
|
# Start Control API
|
|
224
294
|
is_simulation = args.simulation
|
|
@@ -229,7 +299,7 @@ def run_superlink() -> None:
|
|
|
229
299
|
objectstore_factory=objectstore_factory,
|
|
230
300
|
certificates=certificates,
|
|
231
301
|
is_simulation=is_simulation,
|
|
232
|
-
|
|
302
|
+
authn_plugin=authn_plugin,
|
|
233
303
|
authz_plugin=authz_plugin,
|
|
234
304
|
event_log_plugin=event_log_plugin,
|
|
235
305
|
artifact_provider=artifact_provider,
|
|
@@ -306,22 +376,8 @@ def run_superlink() -> None:
|
|
|
306
376
|
fleet_thread.start()
|
|
307
377
|
bckg_threads.append(fleet_thread)
|
|
308
378
|
elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE:
|
|
309
|
-
node_public_keys = _try_load_public_keys_node_authentication(args)
|
|
310
|
-
auto_auth = True
|
|
311
|
-
if node_public_keys is not None:
|
|
312
|
-
auto_auth = False
|
|
313
|
-
state = state_factory.state()
|
|
314
|
-
state.clear_supernode_auth_keys()
|
|
315
|
-
state.store_node_public_keys(node_public_keys)
|
|
316
|
-
log(
|
|
317
|
-
INFO,
|
|
318
|
-
"Node authentication enabled with %d known public keys",
|
|
319
|
-
len(node_public_keys),
|
|
320
|
-
)
|
|
321
|
-
else:
|
|
322
|
-
log(DEBUG, "Automatic node authentication enabled")
|
|
323
379
|
|
|
324
|
-
interceptors = [
|
|
380
|
+
interceptors = [NodeAuthServerInterceptor(state_factory)]
|
|
325
381
|
if getattr(args, "enable_event_log", None):
|
|
326
382
|
fleet_log_plugin = _try_obtain_fleet_event_log_writer_plugin()
|
|
327
383
|
if fleet_log_plugin is not None:
|
|
@@ -333,6 +389,7 @@ def run_superlink() -> None:
|
|
|
333
389
|
state_factory=state_factory,
|
|
334
390
|
ffs_factory=ffs_factory,
|
|
335
391
|
objectstore_factory=objectstore_factory,
|
|
392
|
+
enable_supernode_auth=enable_supernode_auth,
|
|
336
393
|
certificates=certificates,
|
|
337
394
|
interceptors=interceptors,
|
|
338
395
|
)
|
|
@@ -400,55 +457,21 @@ def _format_address(address: str) -> tuple[str, str, int]:
|
|
|
400
457
|
return (f"[{host}]:{port}" if is_v6 else f"{host}:{port}", host, port)
|
|
401
458
|
|
|
402
459
|
|
|
403
|
-
def
|
|
404
|
-
|
|
405
|
-
) ->
|
|
406
|
-
"""Return a set of node public keys."""
|
|
407
|
-
if args.auth_superlink_private_key or args.auth_superlink_public_key:
|
|
408
|
-
log(
|
|
409
|
-
WARN,
|
|
410
|
-
"The `--auth-superlink-private-key` and `--auth-superlink-public-key` "
|
|
411
|
-
"arguments are deprecated and will be removed in a future release. Node "
|
|
412
|
-
"authentication no longer requires these arguments.",
|
|
413
|
-
)
|
|
414
|
-
|
|
415
|
-
if not args.auth_list_public_keys:
|
|
416
|
-
return None
|
|
417
|
-
|
|
418
|
-
node_keys_file_path = Path(args.auth_list_public_keys)
|
|
419
|
-
if not node_keys_file_path.exists():
|
|
420
|
-
sys.exit(
|
|
421
|
-
"The provided path to the known public keys CSV file does not exist: "
|
|
422
|
-
f"{node_keys_file_path}. "
|
|
423
|
-
"Please provide the CSV file path containing known public keys "
|
|
424
|
-
"to '--auth-list-public-keys'."
|
|
425
|
-
)
|
|
426
|
-
|
|
427
|
-
node_public_keys: set[bytes] = set()
|
|
428
|
-
|
|
429
|
-
with open(node_keys_file_path, newline="", encoding="utf-8") as csvfile:
|
|
430
|
-
reader = csv.reader(csvfile)
|
|
431
|
-
for row in reader:
|
|
432
|
-
for element in row:
|
|
433
|
-
public_key = load_ssh_public_key(element.encode())
|
|
434
|
-
if isinstance(public_key, ec.EllipticCurvePublicKey):
|
|
435
|
-
node_public_keys.add(public_key_to_bytes(public_key))
|
|
436
|
-
else:
|
|
437
|
-
sys.exit(
|
|
438
|
-
"Error: Unable to parse the public keys in the CSV "
|
|
439
|
-
"file. Please ensure that the CSV file path points to a valid "
|
|
440
|
-
"known SSH public keys files and try again."
|
|
441
|
-
)
|
|
442
|
-
return node_public_keys
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
def _try_obtain_control_auth_plugins(
|
|
446
|
-
config_path: Path, verify_tls_cert: bool
|
|
447
|
-
) -> tuple[ControlAuthPlugin, ControlAuthzPlugin]:
|
|
460
|
+
def _load_control_auth_plugins(
|
|
461
|
+
config_path: Optional[str], verify_tls_cert: bool
|
|
462
|
+
) -> tuple[ControlAuthnPlugin, ControlAuthzPlugin]:
|
|
448
463
|
"""Obtain Control API authentication and authorization plugins."""
|
|
464
|
+
# Load NoOp plugins if no config path is provided
|
|
465
|
+
if config_path is None:
|
|
466
|
+
config_path = ""
|
|
467
|
+
config = {
|
|
468
|
+
"authentication": {AUTHN_TYPE_YAML_KEY: AuthnType.NOOP},
|
|
469
|
+
"authorization": {AUTHZ_TYPE_YAML_KEY: AuthzType.NOOP},
|
|
470
|
+
}
|
|
449
471
|
# Load YAML file
|
|
450
|
-
|
|
451
|
-
|
|
472
|
+
else:
|
|
473
|
+
with Path(config_path).open("r", encoding="utf-8") as file:
|
|
474
|
+
config = yaml.safe_load(file)
|
|
452
475
|
|
|
453
476
|
def _load_plugin(
|
|
454
477
|
section: str, yaml_key: str, loader: Callable[[], dict[str, type[P]]]
|
|
@@ -458,9 +481,7 @@ def _try_obtain_control_auth_plugins(
|
|
|
458
481
|
try:
|
|
459
482
|
plugins: dict[str, type[P]] = loader()
|
|
460
483
|
plugin_cls: type[P] = plugins[auth_plugin_name]
|
|
461
|
-
return plugin_cls(
|
|
462
|
-
user_auth_config_path=config_path, verify_tls_cert=verify_tls_cert
|
|
463
|
-
)
|
|
484
|
+
return plugin_cls(Path(cast(str, config_path)), verify_tls_cert)
|
|
464
485
|
except KeyError:
|
|
465
486
|
if auth_plugin_name:
|
|
466
487
|
sys.exit(
|
|
@@ -468,14 +489,22 @@ def _try_obtain_control_auth_plugins(
|
|
|
468
489
|
f"Please provide a valid {section} type in the configuration."
|
|
469
490
|
)
|
|
470
491
|
sys.exit(f"No {section} type is provided in the configuration.")
|
|
471
|
-
|
|
472
|
-
|
|
492
|
+
|
|
493
|
+
# Warn deprecated auth_type key
|
|
494
|
+
if authn_type := config["authentication"].pop("auth_type", None):
|
|
495
|
+
log(
|
|
496
|
+
WARN,
|
|
497
|
+
"The `auth_type` key in the authentication configuration is deprecated. "
|
|
498
|
+
"Use `%s` instead.",
|
|
499
|
+
AUTHN_TYPE_YAML_KEY,
|
|
500
|
+
)
|
|
501
|
+
config["authentication"][AUTHN_TYPE_YAML_KEY] = authn_type
|
|
473
502
|
|
|
474
503
|
# Load authentication plugin
|
|
475
|
-
|
|
504
|
+
authn_plugin = _load_plugin(
|
|
476
505
|
section="authentication",
|
|
477
|
-
yaml_key=
|
|
478
|
-
loader=
|
|
506
|
+
yaml_key=AUTHN_TYPE_YAML_KEY,
|
|
507
|
+
loader=get_control_authn_plugins,
|
|
479
508
|
)
|
|
480
509
|
|
|
481
510
|
# Load authorization plugin
|
|
@@ -485,7 +514,7 @@ def _try_obtain_control_auth_plugins(
|
|
|
485
514
|
loader=get_control_authz_plugins,
|
|
486
515
|
)
|
|
487
516
|
|
|
488
|
-
return
|
|
517
|
+
return authn_plugin, authz_plugin
|
|
489
518
|
|
|
490
519
|
|
|
491
520
|
def _try_obtain_control_event_log_writer_plugin() -> Optional[EventLogWriterPlugin]:
|
|
@@ -521,6 +550,7 @@ def _run_fleet_api_grpc_rere( # pylint: disable=R0913, R0917
|
|
|
521
550
|
state_factory: LinkStateFactory,
|
|
522
551
|
ffs_factory: FfsFactory,
|
|
523
552
|
objectstore_factory: ObjectStoreFactory,
|
|
553
|
+
enable_supernode_auth: bool,
|
|
524
554
|
certificates: Optional[tuple[bytes, bytes, bytes]],
|
|
525
555
|
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None,
|
|
526
556
|
) -> grpc.Server:
|
|
@@ -530,6 +560,7 @@ def _run_fleet_api_grpc_rere( # pylint: disable=R0913, R0917
|
|
|
530
560
|
state_factory=state_factory,
|
|
531
561
|
ffs_factory=ffs_factory,
|
|
532
562
|
objectstore_factory=objectstore_factory,
|
|
563
|
+
enable_supernode_auth=enable_supernode_auth,
|
|
533
564
|
)
|
|
534
565
|
fleet_add_servicer_to_server_fn = add_FleetServicer_to_server
|
|
535
566
|
fleet_grpc_server = generic_create_grpc_server(
|
|
@@ -548,6 +579,7 @@ def _run_fleet_api_grpc_rere( # pylint: disable=R0913, R0917
|
|
|
548
579
|
return fleet_grpc_server
|
|
549
580
|
|
|
550
581
|
|
|
582
|
+
# pylint: disable=R0913, R0917
|
|
551
583
|
def _run_fleet_api_grpc_adapter(
|
|
552
584
|
address: str,
|
|
553
585
|
state_factory: LinkStateFactory,
|
|
@@ -561,6 +593,7 @@ def _run_fleet_api_grpc_adapter(
|
|
|
561
593
|
state_factory=state_factory,
|
|
562
594
|
ffs_factory=ffs_factory,
|
|
563
595
|
objectstore_factory=objectstore_factory,
|
|
596
|
+
enable_supernode_auth=False,
|
|
564
597
|
)
|
|
565
598
|
fleet_add_servicer_to_server_fn = add_GrpcAdapterServicer_to_server
|
|
566
599
|
fleet_grpc_server = generic_create_grpc_server(
|
|
@@ -691,11 +724,9 @@ def _add_args_common(parser: argparse.ArgumentParser) -> None:
|
|
|
691
724
|
parser.add_argument(
|
|
692
725
|
"--database",
|
|
693
726
|
help="A string representing the path to the database "
|
|
694
|
-
"file that will be opened.
|
|
695
|
-
"will open a connection to a database that is in RAM, "
|
|
696
|
-
"instead of on disk. If nothing is provided, "
|
|
727
|
+
"file that will be opened. If nothing is provided, "
|
|
697
728
|
"Flower will just create a state in memory.",
|
|
698
|
-
default=
|
|
729
|
+
default=FLWR_IN_MEMORY_DB_NAME,
|
|
699
730
|
)
|
|
700
731
|
parser.add_argument(
|
|
701
732
|
"--storage-dir",
|
|
@@ -705,18 +736,12 @@ def _add_args_common(parser: argparse.ArgumentParser) -> None:
|
|
|
705
736
|
parser.add_argument(
|
|
706
737
|
"--auth-list-public-keys",
|
|
707
738
|
type=str,
|
|
708
|
-
help="A CSV file (as a path str) containing a list of known public "
|
|
709
|
-
"keys to enable authentication.",
|
|
710
|
-
)
|
|
711
|
-
parser.add_argument(
|
|
712
|
-
"--auth-superlink-private-key",
|
|
713
|
-
type=str,
|
|
714
739
|
help="This argument is deprecated and will be removed in a future release.",
|
|
715
740
|
)
|
|
716
741
|
parser.add_argument(
|
|
717
|
-
"--
|
|
718
|
-
|
|
719
|
-
help="
|
|
742
|
+
"--enable-supernode-auth",
|
|
743
|
+
action="store_true",
|
|
744
|
+
help="Enable supernode authentication.",
|
|
720
745
|
)
|
|
721
746
|
|
|
722
747
|
|
|
@@ -33,10 +33,12 @@ from flwr.common.version import package_name, package_version
|
|
|
33
33
|
from flwr.proto import grpcadapter_pb2_grpc # pylint: disable=E0611
|
|
34
34
|
from flwr.proto.fab_pb2 import GetFabRequest # pylint: disable=E0611
|
|
35
35
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
36
|
-
|
|
37
|
-
|
|
36
|
+
ActivateNodeRequest,
|
|
37
|
+
DeactivateNodeRequest,
|
|
38
38
|
PullMessagesRequest,
|
|
39
39
|
PushMessagesRequest,
|
|
40
|
+
RegisterNodeFleetRequest,
|
|
41
|
+
UnregisterNodeFleetRequest,
|
|
40
42
|
)
|
|
41
43
|
from flwr.proto.grpcadapter_pb2 import MessageContainer # pylint: disable=E0611
|
|
42
44
|
from flwr.proto.heartbeat_pb2 import SendNodeHeartbeatRequest # pylint: disable=E0611
|
|
@@ -77,15 +79,23 @@ def _handle(
|
|
|
77
79
|
class GrpcAdapterServicer(grpcadapter_pb2_grpc.GrpcAdapterServicer, FleetServicer):
|
|
78
80
|
"""Fleet API via GrpcAdapter servicer."""
|
|
79
81
|
|
|
80
|
-
def SendReceive( # pylint: disable=too-many-return-statements
|
|
82
|
+
def SendReceive( # pylint: disable=too-many-return-statements, too-many-branches
|
|
81
83
|
self, request: MessageContainer, context: grpc.ServicerContext
|
|
82
84
|
) -> MessageContainer:
|
|
83
85
|
"""."""
|
|
84
86
|
log(DEBUG, "GrpcAdapterServicer.SendReceive")
|
|
85
|
-
if request.grpc_message_name ==
|
|
86
|
-
return _handle(
|
|
87
|
-
|
|
88
|
-
|
|
87
|
+
if request.grpc_message_name == RegisterNodeFleetRequest.__qualname__:
|
|
88
|
+
return _handle(
|
|
89
|
+
request, context, RegisterNodeFleetRequest, self.RegisterNode
|
|
90
|
+
)
|
|
91
|
+
if request.grpc_message_name == ActivateNodeRequest.__qualname__:
|
|
92
|
+
return _handle(request, context, ActivateNodeRequest, self.ActivateNode)
|
|
93
|
+
if request.grpc_message_name == DeactivateNodeRequest.__qualname__:
|
|
94
|
+
return _handle(request, context, DeactivateNodeRequest, self.DeactivateNode)
|
|
95
|
+
if request.grpc_message_name == UnregisterNodeFleetRequest.__qualname__:
|
|
96
|
+
return _handle(
|
|
97
|
+
request, context, UnregisterNodeFleetRequest, self.UnregisterNode
|
|
98
|
+
)
|
|
89
99
|
if request.grpc_message_name == SendNodeHeartbeatRequest.__qualname__:
|
|
90
100
|
return _handle(
|
|
91
101
|
request, context, SendNodeHeartbeatRequest, self.SendNodeHeartbeat
|