flwr 1.22.0__py3-none-any.whl → 1.23.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/app.py +15 -1
- flwr/cli/auth_plugin/__init__.py +15 -6
- flwr/cli/auth_plugin/auth_plugin.py +95 -0
- flwr/cli/auth_plugin/noop_auth_plugin.py +58 -0
- flwr/cli/auth_plugin/oidc_cli_plugin.py +16 -25
- flwr/cli/build.py +118 -47
- flwr/cli/{cli_user_auth_interceptor.py → cli_account_auth_interceptor.py} +6 -5
- flwr/cli/log.py +2 -2
- flwr/cli/login/login.py +34 -23
- flwr/cli/ls.py +13 -9
- flwr/cli/new/new.py +187 -35
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +1 -1
- flwr/cli/pull.py +2 -2
- flwr/cli/run/run.py +11 -7
- flwr/cli/stop.py +2 -2
- flwr/cli/supernode/__init__.py +25 -0
- flwr/cli/supernode/ls.py +260 -0
- flwr/cli/supernode/register.py +185 -0
- flwr/cli/supernode/unregister.py +138 -0
- flwr/cli/utils.py +92 -69
- flwr/client/__init__.py +2 -1
- flwr/client/grpc_adapter_client/connection.py +6 -8
- flwr/client/grpc_rere_client/connection.py +59 -31
- flwr/client/grpc_rere_client/grpc_adapter.py +28 -12
- flwr/client/grpc_rere_client/{client_interceptor.py → node_auth_client_interceptor.py} +3 -6
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +7 -5
- flwr/client/rest_client/connection.py +82 -37
- flwr/clientapp/__init__.py +1 -2
- flwr/{client/clientapp → clientapp}/utils.py +1 -1
- flwr/common/constant.py +53 -13
- flwr/common/exit/exit_code.py +20 -10
- flwr/common/inflatable_utils.py +10 -10
- flwr/common/record/array.py +3 -3
- flwr/common/record/arrayrecord.py +10 -1
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -89
- flwr/common/serde.py +4 -2
- flwr/common/typing.py +7 -6
- flwr/compat/client/app.py +1 -1
- flwr/compat/client/grpc_client/connection.py +2 -2
- flwr/proto/control_pb2.py +48 -35
- flwr/proto/control_pb2.pyi +71 -5
- flwr/proto/control_pb2_grpc.py +102 -0
- flwr/proto/control_pb2_grpc.pyi +39 -0
- flwr/proto/fab_pb2.py +11 -7
- flwr/proto/fab_pb2.pyi +21 -1
- flwr/proto/fleet_pb2.py +31 -23
- flwr/proto/fleet_pb2.pyi +63 -23
- flwr/proto/fleet_pb2_grpc.py +98 -28
- flwr/proto/fleet_pb2_grpc.pyi +45 -13
- flwr/proto/node_pb2.py +3 -1
- flwr/proto/node_pb2.pyi +48 -0
- flwr/server/app.py +139 -114
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +17 -7
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +132 -38
- flwr/server/superlink/fleet/grpc_rere/{server_interceptor.py → node_auth_server_interceptor.py} +27 -51
- flwr/server/superlink/fleet/message_handler/message_handler.py +67 -22
- flwr/server/superlink/fleet/rest_rere/rest_api.py +52 -31
- flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -1
- flwr/server/superlink/fleet/vce/vce_api.py +18 -5
- flwr/server/superlink/linkstate/in_memory_linkstate.py +167 -73
- flwr/server/superlink/linkstate/linkstate.py +107 -24
- flwr/server/superlink/linkstate/linkstate_factory.py +2 -1
- flwr/server/superlink/linkstate/sqlite_linkstate.py +306 -255
- flwr/server/superlink/linkstate/utils.py +3 -54
- flwr/server/superlink/serverappio/serverappio_servicer.py +2 -2
- flwr/server/superlink/simulation/simulationio_servicer.py +1 -1
- flwr/server/utils/validator.py +2 -3
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +4 -2
- flwr/simulation/ray_transport/ray_actor.py +1 -1
- flwr/simulation/ray_transport/ray_client_proxy.py +1 -1
- flwr/simulation/run_simulation.py +3 -2
- flwr/supercore/constant.py +22 -0
- flwr/supercore/object_store/in_memory_object_store.py +0 -4
- flwr/supercore/object_store/object_store_factory.py +26 -6
- flwr/supercore/object_store/sqlite_object_store.py +252 -0
- flwr/{client/clientapp → supercore/primitives}/__init__.py +1 -1
- flwr/supercore/primitives/asymmetric.py +117 -0
- flwr/supercore/primitives/asymmetric_ed25519.py +165 -0
- flwr/supercore/sqlite_mixin.py +156 -0
- flwr/supercore/utils.py +20 -0
- flwr/{common → superlink}/auth_plugin/__init__.py +6 -6
- flwr/superlink/auth_plugin/auth_plugin.py +91 -0
- flwr/superlink/auth_plugin/noop_auth_plugin.py +87 -0
- flwr/superlink/servicer/control/{control_user_auth_interceptor.py → control_account_auth_interceptor.py} +19 -19
- flwr/superlink/servicer/control/control_event_log_interceptor.py +1 -1
- flwr/superlink/servicer/control/control_grpc.py +13 -11
- flwr/superlink/servicer/control/control_servicer.py +152 -60
- flwr/supernode/cli/flower_supernode.py +19 -26
- flwr/supernode/runtime/run_clientapp.py +2 -2
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +1 -1
- flwr/supernode/start_client_internal.py +17 -9
- {flwr-1.22.0.dist-info → flwr-1.23.0.dist-info}/METADATA +1 -1
- {flwr-1.22.0.dist-info → flwr-1.23.0.dist-info}/RECORD +107 -96
- flwr/common/auth_plugin/auth_plugin.py +0 -149
- /flwr/{client → clientapp}/client_app.py +0 -0
- {flwr-1.22.0.dist-info → flwr-1.23.0.dist-info}/WHEEL +0 -0
- {flwr-1.22.0.dist-info → flwr-1.23.0.dist-info}/entry_points.txt +0 -0
flwr/cli/utils.py
CHANGED
|
@@ -26,16 +26,18 @@ from typing import Any, Callable, Optional, Union, cast
|
|
|
26
26
|
import grpc
|
|
27
27
|
import typer
|
|
28
28
|
|
|
29
|
-
from flwr.cli.cli_user_auth_interceptor import CliUserAuthInterceptor
|
|
30
|
-
from flwr.common.auth_plugin import CliAuthPlugin
|
|
31
29
|
from flwr.common.constant import (
|
|
32
|
-
|
|
30
|
+
AUTHN_TYPE_JSON_KEY,
|
|
33
31
|
CREDENTIALS_DIR,
|
|
34
32
|
FLWR_DIR,
|
|
33
|
+
NO_ACCOUNT_AUTH_MESSAGE,
|
|
35
34
|
NO_ARTIFACT_PROVIDER_MESSAGE,
|
|
36
|
-
|
|
35
|
+
NODE_NOT_FOUND_MESSAGE,
|
|
36
|
+
PUBLIC_KEY_ALREADY_IN_USE_MESSAGE,
|
|
37
|
+
PUBLIC_KEY_NOT_VALID,
|
|
37
38
|
PULL_UNFINISHED_RUN_MESSAGE,
|
|
38
39
|
RUN_ID_NOT_FOUND_MESSAGE,
|
|
40
|
+
AuthnType,
|
|
39
41
|
)
|
|
40
42
|
from flwr.common.grpc import (
|
|
41
43
|
GRPC_MAX_MESSAGE_LENGTH,
|
|
@@ -43,7 +45,8 @@ from flwr.common.grpc import (
|
|
|
43
45
|
on_channel_state_change,
|
|
44
46
|
)
|
|
45
47
|
|
|
46
|
-
from .auth_plugin import
|
|
48
|
+
from .auth_plugin import CliAuthPlugin, get_cli_plugin_class
|
|
49
|
+
from .cli_account_auth_interceptor import CliAccountAuthInterceptor
|
|
47
50
|
from .config_utils import validate_certificate_in_federation_config
|
|
48
51
|
|
|
49
52
|
|
|
@@ -166,8 +169,8 @@ def get_sha256_hash(file_path_or_int: Union[Path, int]) -> str:
|
|
|
166
169
|
return sha256.hexdigest()
|
|
167
170
|
|
|
168
171
|
|
|
169
|
-
def
|
|
170
|
-
"""Return the path to the
|
|
172
|
+
def get_account_auth_config_path(root_dir: Path, federation: str) -> Path:
|
|
173
|
+
"""Return the path to the account auth config file.
|
|
171
174
|
|
|
172
175
|
Additionally, a `.gitignore` file will be created in the Flower directory to
|
|
173
176
|
include the `.credentials` folder to be excluded from git. If the `.gitignore`
|
|
@@ -217,71 +220,68 @@ def get_user_auth_config_path(root_dir: Path, federation: str) -> Path:
|
|
|
217
220
|
return credentials_dir / f"{federation}.json"
|
|
218
221
|
|
|
219
222
|
|
|
220
|
-
def
|
|
223
|
+
def account_auth_enabled(federation_config: dict[str, Any]) -> bool:
|
|
224
|
+
"""Check if account authentication is enabled in the federation config."""
|
|
225
|
+
enabled: bool = federation_config.get("enable-user-auth", False)
|
|
226
|
+
enabled |= federation_config.get("enable-account-auth", False)
|
|
227
|
+
if "enable-user-auth" in federation_config:
|
|
228
|
+
typer.secho(
|
|
229
|
+
"`enable-user-auth` is deprecated and will be removed in a future "
|
|
230
|
+
"release. Please use `enable-account-auth` instead.",
|
|
231
|
+
fg=typer.colors.YELLOW,
|
|
232
|
+
bold=True,
|
|
233
|
+
)
|
|
234
|
+
return enabled
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
def retrieve_authn_type(config_path: Path) -> str:
|
|
238
|
+
"""Retrieve the auth type from the config file or return NOOP if not found."""
|
|
239
|
+
try:
|
|
240
|
+
with config_path.open("r", encoding="utf-8") as file:
|
|
241
|
+
json_file = json.load(file)
|
|
242
|
+
authn_type: str = json_file[AUTHN_TYPE_JSON_KEY]
|
|
243
|
+
return authn_type
|
|
244
|
+
except (FileNotFoundError, KeyError):
|
|
245
|
+
return AuthnType.NOOP
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
def load_cli_auth_plugin(
|
|
221
249
|
root_dir: Path,
|
|
222
250
|
federation: str,
|
|
223
251
|
federation_config: dict[str, Any],
|
|
224
|
-
|
|
225
|
-
) ->
|
|
226
|
-
"""Load the CLI-side
|
|
227
|
-
#
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
with config_path.open("r", encoding="utf-8") as file:
|
|
238
|
-
json_file = json.load(file)
|
|
239
|
-
auth_type = json_file[AUTH_TYPE_JSON_KEY]
|
|
240
|
-
except (FileNotFoundError, KeyError):
|
|
241
|
-
typer.secho(
|
|
242
|
-
"❌ Missing or invalid credentials for user authentication. "
|
|
243
|
-
"Please run `flwr login` to authenticate.",
|
|
244
|
-
fg=typer.colors.RED,
|
|
245
|
-
bold=True,
|
|
246
|
-
)
|
|
247
|
-
raise typer.Exit(code=1) from None
|
|
252
|
+
authn_type: Optional[str] = None,
|
|
253
|
+
) -> CliAuthPlugin:
|
|
254
|
+
"""Load the CLI-side account auth plugin for the given authn type."""
|
|
255
|
+
# Find the path to the account auth config file
|
|
256
|
+
config_path = get_account_auth_config_path(root_dir, federation)
|
|
257
|
+
|
|
258
|
+
# Determine the auth type if not provided
|
|
259
|
+
# Only `flwr login` command can provide `authn_type` explicitly, as it can query the
|
|
260
|
+
# SuperLink for the auth type.
|
|
261
|
+
if authn_type is None:
|
|
262
|
+
authn_type = AuthnType.NOOP
|
|
263
|
+
if account_auth_enabled(federation_config):
|
|
264
|
+
authn_type = retrieve_authn_type(config_path)
|
|
248
265
|
|
|
249
266
|
# Retrieve auth plugin class and instantiate it
|
|
250
267
|
try:
|
|
251
|
-
|
|
252
|
-
auth_plugin_class = all_plugins[auth_type]
|
|
268
|
+
auth_plugin_class = get_cli_plugin_class(authn_type)
|
|
253
269
|
return auth_plugin_class(config_path)
|
|
254
|
-
except
|
|
255
|
-
typer.echo(f"❌ Unknown
|
|
256
|
-
raise typer.Exit(code=1) from None
|
|
257
|
-
except ImportError:
|
|
258
|
-
typer.echo("❌ No authentication plugins are currently supported.")
|
|
270
|
+
except ValueError:
|
|
271
|
+
typer.echo(f"❌ Unknown account authentication type: {authn_type}")
|
|
259
272
|
raise typer.Exit(code=1) from None
|
|
260
273
|
|
|
261
274
|
|
|
262
275
|
def init_channel(
|
|
263
|
-
app: Path, federation_config: dict[str, Any], auth_plugin:
|
|
276
|
+
app: Path, federation_config: dict[str, Any], auth_plugin: CliAuthPlugin
|
|
264
277
|
) -> grpc.Channel:
|
|
265
278
|
"""Initialize gRPC channel to the Control API."""
|
|
266
279
|
insecure, root_certificates_bytes = validate_certificate_in_federation_config(
|
|
267
280
|
app, federation_config
|
|
268
281
|
)
|
|
269
282
|
|
|
270
|
-
#
|
|
271
|
-
|
|
272
|
-
if auth_plugin is not None:
|
|
273
|
-
# Check if TLS is enabled. If not, raise an error
|
|
274
|
-
if insecure:
|
|
275
|
-
typer.secho(
|
|
276
|
-
"❌ User authentication requires TLS to be enabled. "
|
|
277
|
-
"Remove `insecure = true` from the federation configuration.",
|
|
278
|
-
fg=typer.colors.RED,
|
|
279
|
-
bold=True,
|
|
280
|
-
)
|
|
281
|
-
raise typer.Exit(code=1)
|
|
282
|
-
|
|
283
|
-
auth_plugin.load_tokens()
|
|
284
|
-
interceptors.append(CliUserAuthInterceptor(auth_plugin))
|
|
283
|
+
# Load tokens
|
|
284
|
+
auth_plugin.load_tokens()
|
|
285
285
|
|
|
286
286
|
# Create the gRPC channel
|
|
287
287
|
channel = create_channel(
|
|
@@ -289,14 +289,14 @@ def init_channel(
|
|
|
289
289
|
insecure=insecure,
|
|
290
290
|
root_certificates=root_certificates_bytes,
|
|
291
291
|
max_message_length=GRPC_MAX_MESSAGE_LENGTH,
|
|
292
|
-
interceptors=
|
|
292
|
+
interceptors=[CliAccountAuthInterceptor(auth_plugin)],
|
|
293
293
|
)
|
|
294
294
|
channel.subscribe(on_channel_state_change)
|
|
295
295
|
return channel
|
|
296
296
|
|
|
297
297
|
|
|
298
298
|
@contextmanager
|
|
299
|
-
def flwr_cli_grpc_exc_handler() -> Iterator[None]:
|
|
299
|
+
def flwr_cli_grpc_exc_handler() -> Iterator[None]: # pylint: disable=too-many-branches
|
|
300
300
|
"""Context manager to handle specific gRPC errors.
|
|
301
301
|
|
|
302
302
|
It catches grpc.RpcError exceptions with UNAUTHENTICATED, UNIMPLEMENTED,
|
|
@@ -315,9 +315,9 @@ def flwr_cli_grpc_exc_handler() -> Iterator[None]:
|
|
|
315
315
|
)
|
|
316
316
|
raise typer.Exit(code=1) from None
|
|
317
317
|
if e.code() == grpc.StatusCode.UNIMPLEMENTED:
|
|
318
|
-
if e.details() ==
|
|
318
|
+
if e.details() == NO_ACCOUNT_AUTH_MESSAGE: # pylint: disable=E1101
|
|
319
319
|
typer.secho(
|
|
320
|
-
"❌
|
|
320
|
+
"❌ Account authentication is not enabled on this SuperLink.",
|
|
321
321
|
fg=typer.colors.RED,
|
|
322
322
|
bold=True,
|
|
323
323
|
)
|
|
@@ -354,16 +354,21 @@ def flwr_cli_grpc_exc_handler() -> Iterator[None]:
|
|
|
354
354
|
bold=True,
|
|
355
355
|
)
|
|
356
356
|
raise typer.Exit(code=1) from None
|
|
357
|
-
if (
|
|
358
|
-
e.
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
)
|
|
366
|
-
|
|
357
|
+
if e.code() == grpc.StatusCode.NOT_FOUND:
|
|
358
|
+
if e.details() == RUN_ID_NOT_FOUND_MESSAGE: # pylint: disable=E1101
|
|
359
|
+
typer.secho(
|
|
360
|
+
"❌ Run ID not found.",
|
|
361
|
+
fg=typer.colors.RED,
|
|
362
|
+
bold=True,
|
|
363
|
+
)
|
|
364
|
+
raise typer.Exit(code=1) from None
|
|
365
|
+
if e.details() == NODE_NOT_FOUND_MESSAGE: # pylint: disable=E1101
|
|
366
|
+
typer.secho(
|
|
367
|
+
"❌ Node ID not found for this account.",
|
|
368
|
+
fg=typer.colors.RED,
|
|
369
|
+
bold=True,
|
|
370
|
+
)
|
|
371
|
+
raise typer.Exit(code=1) from None
|
|
367
372
|
if e.code() == grpc.StatusCode.FAILED_PRECONDITION:
|
|
368
373
|
if e.details() == PULL_UNFINISHED_RUN_MESSAGE: # pylint: disable=E1101
|
|
369
374
|
typer.secho(
|
|
@@ -373,4 +378,22 @@ def flwr_cli_grpc_exc_handler() -> Iterator[None]:
|
|
|
373
378
|
bold=True,
|
|
374
379
|
)
|
|
375
380
|
raise typer.Exit(code=1) from None
|
|
381
|
+
if (
|
|
382
|
+
e.details() == PUBLIC_KEY_ALREADY_IN_USE_MESSAGE
|
|
383
|
+
): # pylint: disable=E1101
|
|
384
|
+
typer.secho(
|
|
385
|
+
"❌ The provided public key is already in use by another "
|
|
386
|
+
"SuperNode.",
|
|
387
|
+
fg=typer.colors.RED,
|
|
388
|
+
bold=True,
|
|
389
|
+
)
|
|
390
|
+
raise typer.Exit(code=1) from None
|
|
391
|
+
if e.details() == PUBLIC_KEY_NOT_VALID: # pylint: disable=E1101
|
|
392
|
+
typer.secho(
|
|
393
|
+
"❌ The provided public key is invalid. Please provide a valid "
|
|
394
|
+
"NIST EC public key.",
|
|
395
|
+
fg=typer.colors.RED,
|
|
396
|
+
bold=True,
|
|
397
|
+
)
|
|
398
|
+
raise typer.Exit(code=1) from None
|
|
376
399
|
raise
|
flwr/client/__init__.py
CHANGED
|
@@ -15,10 +15,11 @@
|
|
|
15
15
|
"""Flower client."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
from flwr.clientapp import ClientApp
|
|
19
|
+
|
|
18
20
|
from ..compat.client.app import start_client as start_client # Deprecated
|
|
19
21
|
from ..compat.client.app import start_numpy_client as start_numpy_client # Deprecated
|
|
20
22
|
from .client import Client as Client
|
|
21
|
-
from .client_app import ClientApp as ClientApp
|
|
22
23
|
from .numpy_client import NumPyClient as NumPyClient
|
|
23
24
|
from .typing import ClientFn as ClientFn
|
|
24
25
|
from .typing import ClientFnExt as ClientFnExt
|
|
@@ -44,10 +44,9 @@ def grpc_adapter( # pylint: disable=R0913,too-many-positional-arguments
|
|
|
44
44
|
] = None,
|
|
45
45
|
) -> Iterator[
|
|
46
46
|
tuple[
|
|
47
|
+
int,
|
|
47
48
|
Callable[[], Optional[tuple[Message, ObjectTree]]],
|
|
48
49
|
Callable[[Message, ObjectTree], set[str]],
|
|
49
|
-
Callable[[], Optional[int]],
|
|
50
|
-
Callable[[], None],
|
|
51
50
|
Callable[[int], Run],
|
|
52
51
|
Callable[[str, int], Fab],
|
|
53
52
|
Callable[[int, str], bytes],
|
|
@@ -77,22 +76,21 @@ def grpc_adapter( # pylint: disable=R0913,too-many-positional-arguments
|
|
|
77
76
|
connection using the certificates will be established to an SSL-enabled
|
|
78
77
|
Flower server. Bytes won't work for the REST API.
|
|
79
78
|
authentication_keys : Optional[Tuple[PrivateKey, PublicKey]] (default: None)
|
|
80
|
-
|
|
79
|
+
SuperNode authentication is not supported for this transport type.
|
|
81
80
|
|
|
82
81
|
Returns
|
|
83
82
|
-------
|
|
83
|
+
node_id : int
|
|
84
84
|
receive : Callable[[], Optional[tuple[Message, ObjectTree]]]
|
|
85
85
|
send : Callable[[Message, ObjectTree], set[str]]
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
get_run : Optional[Callable]
|
|
89
|
-
get_fab : Optional[Callable]
|
|
86
|
+
get_run : Callable[[int], Run]
|
|
87
|
+
get_fab : Callable[[str, int], Fab]
|
|
90
88
|
pull_object : Callable[[str], bytes]
|
|
91
89
|
push_object : Callable[[str, bytes], None]
|
|
92
90
|
confirm_message_received : Callable[[str], None]
|
|
93
91
|
"""
|
|
94
92
|
if authentication_keys is not None:
|
|
95
|
-
log(ERROR, "
|
|
93
|
+
log(ERROR, "SuperNode authentication is not supported for this transport type.")
|
|
96
94
|
with grpc_request_response(
|
|
97
95
|
server_address=server_address,
|
|
98
96
|
insecure=insecure,
|
|
@@ -19,7 +19,7 @@ from collections.abc import Iterator, Sequence
|
|
|
19
19
|
from contextlib import contextmanager
|
|
20
20
|
from logging import ERROR
|
|
21
21
|
from pathlib import Path
|
|
22
|
-
from typing import Callable, Optional, Union
|
|
22
|
+
from typing import Callable, Optional, Union
|
|
23
23
|
|
|
24
24
|
import grpc
|
|
25
25
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
@@ -36,19 +36,24 @@ from flwr.common.inflatable_protobuf_utils import (
|
|
|
36
36
|
from flwr.common.logger import log
|
|
37
37
|
from flwr.common.message import Message, remove_content_from_message
|
|
38
38
|
from flwr.common.retry_invoker import RetryInvoker, _wrap_stub
|
|
39
|
-
from flwr.common.
|
|
40
|
-
|
|
39
|
+
from flwr.common.serde import (
|
|
40
|
+
fab_from_proto,
|
|
41
|
+
message_from_proto,
|
|
42
|
+
message_to_proto,
|
|
43
|
+
run_from_proto,
|
|
41
44
|
)
|
|
42
|
-
from flwr.common.serde import message_from_proto, message_to_proto, run_from_proto
|
|
43
45
|
from flwr.common.typing import Fab, Run
|
|
44
46
|
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
45
47
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
46
|
-
|
|
47
|
-
|
|
48
|
+
ActivateNodeRequest,
|
|
49
|
+
ActivateNodeResponse,
|
|
50
|
+
DeactivateNodeRequest,
|
|
48
51
|
PullMessagesRequest,
|
|
49
52
|
PullMessagesResponse,
|
|
50
53
|
PushMessagesRequest,
|
|
51
54
|
PushMessagesResponse,
|
|
55
|
+
RegisterNodeFleetRequest,
|
|
56
|
+
UnregisterNodeFleetRequest,
|
|
52
57
|
)
|
|
53
58
|
from flwr.proto.fleet_pb2_grpc import FleetStub # pylint: disable=E0611
|
|
54
59
|
from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
|
|
@@ -58,9 +63,10 @@ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
|
|
|
58
63
|
from flwr.proto.message_pb2 import ObjectTree # pylint: disable=E0611
|
|
59
64
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
60
65
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
66
|
+
from flwr.supercore.primitives.asymmetric import generate_key_pairs, public_key_to_bytes
|
|
61
67
|
|
|
62
|
-
from .client_interceptor import AuthenticateClientInterceptor
|
|
63
68
|
from .grpc_adapter import GrpcAdapter
|
|
69
|
+
from .node_auth_client_interceptor import NodeAuthClientInterceptor
|
|
64
70
|
|
|
65
71
|
|
|
66
72
|
@contextmanager
|
|
@@ -76,10 +82,9 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
76
82
|
adapter_cls: Optional[Union[type[FleetStub], type[GrpcAdapter]]] = None,
|
|
77
83
|
) -> Iterator[
|
|
78
84
|
tuple[
|
|
85
|
+
int,
|
|
79
86
|
Callable[[], Optional[tuple[Message, ObjectTree]]],
|
|
80
87
|
Callable[[Message, ObjectTree], set[str]],
|
|
81
|
-
Callable[[], Optional[int]],
|
|
82
|
-
Callable[[], None],
|
|
83
88
|
Callable[[int], Run],
|
|
84
89
|
Callable[[str, int], Fab],
|
|
85
90
|
Callable[[int, str], bytes],
|
|
@@ -122,11 +127,11 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
122
127
|
|
|
123
128
|
Returns
|
|
124
129
|
-------
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
+
node_id : int
|
|
131
|
+
receive : Callable[[], Optional[tuple[Message, ObjectTree]]]
|
|
132
|
+
send : Callable[[Message, ObjectTree], set[str]]
|
|
133
|
+
get_run : Callable[[int], Run]
|
|
134
|
+
get_fab : Callable[[str, int], Fab]
|
|
130
135
|
pull_object : Callable[[str], bytes]
|
|
131
136
|
push_object : Callable[[str, bytes], None]
|
|
132
137
|
confirm_message_received : Callable[[str], None]
|
|
@@ -135,13 +140,16 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
135
140
|
root_certificates = Path(root_certificates).read_bytes()
|
|
136
141
|
|
|
137
142
|
# Automatic node auth: generate keys if user didn't provide any
|
|
143
|
+
self_registered = False
|
|
138
144
|
if authentication_keys is None:
|
|
145
|
+
self_registered = True
|
|
139
146
|
authentication_keys = generate_key_pairs()
|
|
140
147
|
|
|
141
148
|
# Always configure auth interceptor, with either user-provided or generated keys
|
|
142
149
|
interceptors: Sequence[grpc.UnaryUnaryClientInterceptor] = [
|
|
143
|
-
|
|
150
|
+
NodeAuthClientInterceptor(*authentication_keys),
|
|
144
151
|
]
|
|
152
|
+
node_pk = public_key_to_bytes(authentication_keys[1])
|
|
145
153
|
channel = create_channel(
|
|
146
154
|
server_address=server_address,
|
|
147
155
|
insecure=insecure,
|
|
@@ -160,7 +168,7 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
160
168
|
# Wrap stub
|
|
161
169
|
_wrap_stub(stub, retry_invoker)
|
|
162
170
|
###########################################################################
|
|
163
|
-
#
|
|
171
|
+
# SuperNode functions
|
|
164
172
|
###########################################################################
|
|
165
173
|
|
|
166
174
|
def send_node_heartbeat() -> bool:
|
|
@@ -197,22 +205,26 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
197
205
|
|
|
198
206
|
heartbeat_sender = HeartbeatSender(send_node_heartbeat)
|
|
199
207
|
|
|
200
|
-
def
|
|
201
|
-
"""
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
208
|
+
def register_node() -> None:
|
|
209
|
+
"""Register node with SuperLink."""
|
|
210
|
+
stub.RegisterNode(RegisterNodeFleetRequest(public_key=node_pk))
|
|
211
|
+
|
|
212
|
+
def activate_node() -> int:
|
|
213
|
+
"""Activate node and start heartbeat."""
|
|
214
|
+
req = ActivateNodeRequest(
|
|
215
|
+
public_key=node_pk,
|
|
216
|
+
heartbeat_interval=HEARTBEAT_DEFAULT_INTERVAL,
|
|
205
217
|
)
|
|
206
|
-
|
|
218
|
+
res: ActivateNodeResponse = stub.ActivateNode(req)
|
|
207
219
|
|
|
208
220
|
# Remember the node and start the heartbeat sender
|
|
209
221
|
nonlocal node
|
|
210
|
-
node =
|
|
222
|
+
node = Node(node_id=res.node_id)
|
|
211
223
|
heartbeat_sender.start()
|
|
212
224
|
return node.node_id
|
|
213
225
|
|
|
214
|
-
def
|
|
215
|
-
"""
|
|
226
|
+
def deactivate_node() -> None:
|
|
227
|
+
"""Deactivate node and stop heartbeat."""
|
|
216
228
|
# Get Node
|
|
217
229
|
nonlocal node
|
|
218
230
|
if node is None:
|
|
@@ -223,8 +235,20 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
223
235
|
heartbeat_sender.stop()
|
|
224
236
|
|
|
225
237
|
# Call FleetAPI
|
|
226
|
-
|
|
227
|
-
stub.
|
|
238
|
+
req = DeactivateNodeRequest(node_id=node.node_id)
|
|
239
|
+
stub.DeactivateNode(req)
|
|
240
|
+
|
|
241
|
+
def unregister_node() -> None:
|
|
242
|
+
"""Unregister node from SuperLink."""
|
|
243
|
+
# Get Node
|
|
244
|
+
nonlocal node
|
|
245
|
+
if node is None:
|
|
246
|
+
log(ERROR, "Node instance missing")
|
|
247
|
+
return
|
|
248
|
+
|
|
249
|
+
# Call FleetAPI
|
|
250
|
+
req = UnregisterNodeFleetRequest(node_id=node.node_id)
|
|
251
|
+
stub.UnregisterNode(req)
|
|
228
252
|
|
|
229
253
|
# Cleanup
|
|
230
254
|
node = None
|
|
@@ -289,7 +313,7 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
289
313
|
get_fab_request = GetFabRequest(node=node, hash_str=fab_hash, run_id=run_id)
|
|
290
314
|
get_fab_response: GetFabResponse = stub.GetFab(request=get_fab_request)
|
|
291
315
|
|
|
292
|
-
return
|
|
316
|
+
return fab_from_proto(get_fab_response.fab)
|
|
293
317
|
|
|
294
318
|
def pull_object(run_id: int, object_id: str) -> bytes:
|
|
295
319
|
"""Pull the object from the SuperLink."""
|
|
@@ -331,12 +355,14 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
331
355
|
fn(object_id)
|
|
332
356
|
|
|
333
357
|
try:
|
|
358
|
+
if self_registered:
|
|
359
|
+
register_node()
|
|
360
|
+
node_id = activate_node()
|
|
334
361
|
# Yield methods
|
|
335
362
|
yield (
|
|
363
|
+
node_id,
|
|
336
364
|
receive,
|
|
337
365
|
send,
|
|
338
|
-
create_node,
|
|
339
|
-
delete_node,
|
|
340
366
|
get_run,
|
|
341
367
|
get_fab,
|
|
342
368
|
pull_object,
|
|
@@ -351,7 +377,9 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
351
377
|
if node is not None:
|
|
352
378
|
# Disable retrying
|
|
353
379
|
retry_invoker.max_tries = 1
|
|
354
|
-
|
|
380
|
+
deactivate_node()
|
|
381
|
+
if self_registered:
|
|
382
|
+
unregister_node()
|
|
355
383
|
except grpc.RpcError:
|
|
356
384
|
pass
|
|
357
385
|
channel.close()
|
|
@@ -34,14 +34,18 @@ from flwr.common.constant import (
|
|
|
34
34
|
from flwr.common.version import package_name, package_version
|
|
35
35
|
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
36
36
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
37
|
+
ActivateNodeRequest,
|
|
38
|
+
ActivateNodeResponse,
|
|
39
|
+
DeactivateNodeRequest,
|
|
40
|
+
DeactivateNodeResponse,
|
|
41
41
|
PullMessagesRequest,
|
|
42
42
|
PullMessagesResponse,
|
|
43
43
|
PushMessagesRequest,
|
|
44
44
|
PushMessagesResponse,
|
|
45
|
+
RegisterNodeFleetRequest,
|
|
46
|
+
RegisterNodeFleetResponse,
|
|
47
|
+
UnregisterNodeFleetRequest,
|
|
48
|
+
UnregisterNodeFleetResponse,
|
|
45
49
|
)
|
|
46
50
|
from flwr.proto.grpcadapter_pb2 import MessageContainer # pylint: disable=E0611
|
|
47
51
|
from flwr.proto.grpcadapter_pb2_grpc import GrpcAdapterStub
|
|
@@ -118,17 +122,29 @@ class GrpcAdapter:
|
|
|
118
122
|
response.ParseFromString(container_res.grpc_message_content)
|
|
119
123
|
return response
|
|
120
124
|
|
|
121
|
-
def
|
|
122
|
-
self, request:
|
|
123
|
-
) ->
|
|
125
|
+
def RegisterNode( # pylint: disable=C0103
|
|
126
|
+
self, request: RegisterNodeFleetRequest, **kwargs: Any
|
|
127
|
+
) -> RegisterNodeFleetResponse:
|
|
124
128
|
"""."""
|
|
125
|
-
return self._send_and_receive(request,
|
|
129
|
+
return self._send_and_receive(request, RegisterNodeFleetResponse, **kwargs)
|
|
126
130
|
|
|
127
|
-
def
|
|
128
|
-
self, request:
|
|
129
|
-
) ->
|
|
131
|
+
def ActivateNode( # pylint: disable=C0103
|
|
132
|
+
self, request: ActivateNodeRequest, **kwargs: Any
|
|
133
|
+
) -> ActivateNodeResponse:
|
|
130
134
|
"""."""
|
|
131
|
-
return self._send_and_receive(request,
|
|
135
|
+
return self._send_and_receive(request, ActivateNodeResponse, **kwargs)
|
|
136
|
+
|
|
137
|
+
def DeactivateNode( # pylint: disable=C0103
|
|
138
|
+
self, request: DeactivateNodeRequest, **kwargs: Any
|
|
139
|
+
) -> DeactivateNodeResponse:
|
|
140
|
+
"""."""
|
|
141
|
+
return self._send_and_receive(request, DeactivateNodeResponse, **kwargs)
|
|
142
|
+
|
|
143
|
+
def UnregisterNode( # pylint: disable=C0103
|
|
144
|
+
self, request: UnregisterNodeFleetRequest, **kwargs: Any
|
|
145
|
+
) -> UnregisterNodeFleetResponse:
|
|
146
|
+
"""."""
|
|
147
|
+
return self._send_and_receive(request, UnregisterNodeFleetResponse, **kwargs)
|
|
132
148
|
|
|
133
149
|
def SendNodeHeartbeat( # pylint: disable=C0103
|
|
134
150
|
self, request: SendNodeHeartbeatRequest, **kwargs: Any
|
|
@@ -23,14 +23,11 @@ from google.protobuf.message import Message as GrpcMessage
|
|
|
23
23
|
|
|
24
24
|
from flwr.common import now
|
|
25
25
|
from flwr.common.constant import PUBLIC_KEY_HEADER, SIGNATURE_HEADER, TIMESTAMP_HEADER
|
|
26
|
-
from flwr.
|
|
27
|
-
public_key_to_bytes,
|
|
28
|
-
sign_message,
|
|
29
|
-
)
|
|
26
|
+
from flwr.supercore.primitives.asymmetric import public_key_to_bytes, sign_message
|
|
30
27
|
|
|
31
28
|
|
|
32
|
-
class
|
|
33
|
-
"""Client interceptor for
|
|
29
|
+
class NodeAuthClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type: ignore
|
|
30
|
+
"""Client interceptor for node authentication."""
|
|
34
31
|
|
|
35
32
|
def __init__(
|
|
36
33
|
self,
|
|
@@ -35,14 +35,9 @@ from flwr.common.constant import MessageType
|
|
|
35
35
|
from flwr.common.logger import log
|
|
36
36
|
from flwr.common.secure_aggregation.crypto.shamir import create_shares
|
|
37
37
|
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
|
38
|
-
bytes_to_private_key,
|
|
39
|
-
bytes_to_public_key,
|
|
40
38
|
decrypt,
|
|
41
39
|
encrypt,
|
|
42
|
-
generate_key_pairs,
|
|
43
40
|
generate_shared_key,
|
|
44
|
-
private_key_to_bytes,
|
|
45
|
-
public_key_to_bytes,
|
|
46
41
|
)
|
|
47
42
|
from flwr.common.secure_aggregation.ndarrays_arithmetic import (
|
|
48
43
|
factor_combine,
|
|
@@ -64,6 +59,13 @@ from flwr.common.secure_aggregation.secaggplus_utils import (
|
|
|
64
59
|
share_keys_plaintext_separate,
|
|
65
60
|
)
|
|
66
61
|
from flwr.common.typing import ConfigRecordValues
|
|
62
|
+
from flwr.supercore.primitives.asymmetric import (
|
|
63
|
+
bytes_to_private_key,
|
|
64
|
+
bytes_to_public_key,
|
|
65
|
+
generate_key_pairs,
|
|
66
|
+
private_key_to_bytes,
|
|
67
|
+
public_key_to_bytes,
|
|
68
|
+
)
|
|
67
69
|
|
|
68
70
|
|
|
69
71
|
@dataclass
|