flwr 1.22.0__py3-none-any.whl → 1.23.0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (108)
  1. flwr/cli/app.py +15 -1
  2. flwr/cli/auth_plugin/__init__.py +15 -6
  3. flwr/cli/auth_plugin/auth_plugin.py +95 -0
  4. flwr/cli/auth_plugin/noop_auth_plugin.py +58 -0
  5. flwr/cli/auth_plugin/oidc_cli_plugin.py +16 -25
  6. flwr/cli/build.py +118 -47
  7. flwr/cli/{cli_user_auth_interceptor.py → cli_account_auth_interceptor.py} +6 -5
  8. flwr/cli/log.py +2 -2
  9. flwr/cli/login/login.py +34 -23
  10. flwr/cli/ls.py +13 -9
  11. flwr/cli/new/new.py +187 -35
  12. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
  13. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
  14. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
  15. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  16. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
  17. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  18. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
  19. flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +1 -1
  20. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  21. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  22. flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +1 -1
  23. flwr/cli/pull.py +2 -2
  24. flwr/cli/run/run.py +11 -7
  25. flwr/cli/stop.py +2 -2
  26. flwr/cli/supernode/__init__.py +25 -0
  27. flwr/cli/supernode/ls.py +260 -0
  28. flwr/cli/supernode/register.py +185 -0
  29. flwr/cli/supernode/unregister.py +138 -0
  30. flwr/cli/utils.py +92 -69
  31. flwr/client/__init__.py +2 -1
  32. flwr/client/grpc_adapter_client/connection.py +6 -8
  33. flwr/client/grpc_rere_client/connection.py +59 -31
  34. flwr/client/grpc_rere_client/grpc_adapter.py +28 -12
  35. flwr/client/grpc_rere_client/{client_interceptor.py → node_auth_client_interceptor.py} +3 -6
  36. flwr/client/mod/secure_aggregation/secaggplus_mod.py +7 -5
  37. flwr/client/rest_client/connection.py +82 -37
  38. flwr/clientapp/__init__.py +1 -2
  39. flwr/{client/clientapp → clientapp}/utils.py +1 -1
  40. flwr/common/constant.py +53 -13
  41. flwr/common/exit/exit_code.py +20 -10
  42. flwr/common/inflatable_utils.py +10 -10
  43. flwr/common/record/array.py +3 -3
  44. flwr/common/record/arrayrecord.py +10 -1
  45. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -89
  46. flwr/common/serde.py +4 -2
  47. flwr/common/typing.py +7 -6
  48. flwr/compat/client/app.py +1 -1
  49. flwr/compat/client/grpc_client/connection.py +2 -2
  50. flwr/proto/control_pb2.py +48 -35
  51. flwr/proto/control_pb2.pyi +71 -5
  52. flwr/proto/control_pb2_grpc.py +102 -0
  53. flwr/proto/control_pb2_grpc.pyi +39 -0
  54. flwr/proto/fab_pb2.py +11 -7
  55. flwr/proto/fab_pb2.pyi +21 -1
  56. flwr/proto/fleet_pb2.py +31 -23
  57. flwr/proto/fleet_pb2.pyi +63 -23
  58. flwr/proto/fleet_pb2_grpc.py +98 -28
  59. flwr/proto/fleet_pb2_grpc.pyi +45 -13
  60. flwr/proto/node_pb2.py +3 -1
  61. flwr/proto/node_pb2.pyi +48 -0
  62. flwr/server/app.py +139 -114
  63. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +17 -7
  64. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +132 -38
  65. flwr/server/superlink/fleet/grpc_rere/{server_interceptor.py → node_auth_server_interceptor.py} +27 -51
  66. flwr/server/superlink/fleet/message_handler/message_handler.py +67 -22
  67. flwr/server/superlink/fleet/rest_rere/rest_api.py +52 -31
  68. flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
  69. flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -1
  70. flwr/server/superlink/fleet/vce/vce_api.py +18 -5
  71. flwr/server/superlink/linkstate/in_memory_linkstate.py +167 -73
  72. flwr/server/superlink/linkstate/linkstate.py +107 -24
  73. flwr/server/superlink/linkstate/linkstate_factory.py +2 -1
  74. flwr/server/superlink/linkstate/sqlite_linkstate.py +306 -255
  75. flwr/server/superlink/linkstate/utils.py +3 -54
  76. flwr/server/superlink/serverappio/serverappio_servicer.py +2 -2
  77. flwr/server/superlink/simulation/simulationio_servicer.py +1 -1
  78. flwr/server/utils/validator.py +2 -3
  79. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +4 -2
  80. flwr/simulation/ray_transport/ray_actor.py +1 -1
  81. flwr/simulation/ray_transport/ray_client_proxy.py +1 -1
  82. flwr/simulation/run_simulation.py +3 -2
  83. flwr/supercore/constant.py +22 -0
  84. flwr/supercore/object_store/in_memory_object_store.py +0 -4
  85. flwr/supercore/object_store/object_store_factory.py +26 -6
  86. flwr/supercore/object_store/sqlite_object_store.py +252 -0
  87. flwr/{client/clientapp → supercore/primitives}/__init__.py +1 -1
  88. flwr/supercore/primitives/asymmetric.py +117 -0
  89. flwr/supercore/primitives/asymmetric_ed25519.py +165 -0
  90. flwr/supercore/sqlite_mixin.py +156 -0
  91. flwr/supercore/utils.py +20 -0
  92. flwr/{common → superlink}/auth_plugin/__init__.py +6 -6
  93. flwr/superlink/auth_plugin/auth_plugin.py +91 -0
  94. flwr/superlink/auth_plugin/noop_auth_plugin.py +87 -0
  95. flwr/superlink/servicer/control/{control_user_auth_interceptor.py → control_account_auth_interceptor.py} +19 -19
  96. flwr/superlink/servicer/control/control_event_log_interceptor.py +1 -1
  97. flwr/superlink/servicer/control/control_grpc.py +13 -11
  98. flwr/superlink/servicer/control/control_servicer.py +152 -60
  99. flwr/supernode/cli/flower_supernode.py +19 -26
  100. flwr/supernode/runtime/run_clientapp.py +2 -2
  101. flwr/supernode/servicer/clientappio/clientappio_servicer.py +1 -1
  102. flwr/supernode/start_client_internal.py +17 -9
  103. {flwr-1.22.0.dist-info → flwr-1.23.0.dist-info}/METADATA +1 -1
  104. {flwr-1.22.0.dist-info → flwr-1.23.0.dist-info}/RECORD +107 -96
  105. flwr/common/auth_plugin/auth_plugin.py +0 -149
  106. /flwr/{client → clientapp}/client_app.py +0 -0
  107. {flwr-1.22.0.dist-info → flwr-1.23.0.dist-info}/WHEEL +0 -0
  108. {flwr-1.22.0.dist-info → flwr-1.23.0.dist-info}/entry_points.txt +0 -0
flwr/cli/utils.py CHANGED
@@ -26,16 +26,18 @@ from typing import Any, Callable, Optional, Union, cast
26
26
  import grpc
27
27
  import typer
28
28
 
29
- from flwr.cli.cli_user_auth_interceptor import CliUserAuthInterceptor
30
- from flwr.common.auth_plugin import CliAuthPlugin
31
29
  from flwr.common.constant import (
32
- AUTH_TYPE_JSON_KEY,
30
+ AUTHN_TYPE_JSON_KEY,
33
31
  CREDENTIALS_DIR,
34
32
  FLWR_DIR,
33
+ NO_ACCOUNT_AUTH_MESSAGE,
35
34
  NO_ARTIFACT_PROVIDER_MESSAGE,
36
- NO_USER_AUTH_MESSAGE,
35
+ NODE_NOT_FOUND_MESSAGE,
36
+ PUBLIC_KEY_ALREADY_IN_USE_MESSAGE,
37
+ PUBLIC_KEY_NOT_VALID,
37
38
  PULL_UNFINISHED_RUN_MESSAGE,
38
39
  RUN_ID_NOT_FOUND_MESSAGE,
40
+ AuthnType,
39
41
  )
40
42
  from flwr.common.grpc import (
41
43
  GRPC_MAX_MESSAGE_LENGTH,
@@ -43,7 +45,8 @@ from flwr.common.grpc import (
43
45
  on_channel_state_change,
44
46
  )
45
47
 
46
- from .auth_plugin import get_cli_auth_plugins
48
+ from .auth_plugin import CliAuthPlugin, get_cli_plugin_class
49
+ from .cli_account_auth_interceptor import CliAccountAuthInterceptor
47
50
  from .config_utils import validate_certificate_in_federation_config
48
51
 
49
52
 
@@ -166,8 +169,8 @@ def get_sha256_hash(file_path_or_int: Union[Path, int]) -> str:
166
169
  return sha256.hexdigest()
167
170
 
168
171
 
169
- def get_user_auth_config_path(root_dir: Path, federation: str) -> Path:
170
- """Return the path to the user auth config file.
172
+ def get_account_auth_config_path(root_dir: Path, federation: str) -> Path:
173
+ """Return the path to the account auth config file.
171
174
 
172
175
  Additionally, a `.gitignore` file will be created in the Flower directory to
173
176
  include the `.credentials` folder to be excluded from git. If the `.gitignore`
@@ -217,71 +220,68 @@ def get_user_auth_config_path(root_dir: Path, federation: str) -> Path:
217
220
  return credentials_dir / f"{federation}.json"
218
221
 
219
222
 
220
- def try_obtain_cli_auth_plugin(
223
+ def account_auth_enabled(federation_config: dict[str, Any]) -> bool:
224
+ """Check if account authentication is enabled in the federation config."""
225
+ enabled: bool = federation_config.get("enable-user-auth", False)
226
+ enabled |= federation_config.get("enable-account-auth", False)
227
+ if "enable-user-auth" in federation_config:
228
+ typer.secho(
229
+ "`enable-user-auth` is deprecated and will be removed in a future "
230
+ "release. Please use `enable-account-auth` instead.",
231
+ fg=typer.colors.YELLOW,
232
+ bold=True,
233
+ )
234
+ return enabled
235
+
236
+
237
+ def retrieve_authn_type(config_path: Path) -> str:
238
+ """Retrieve the auth type from the config file or return NOOP if not found."""
239
+ try:
240
+ with config_path.open("r", encoding="utf-8") as file:
241
+ json_file = json.load(file)
242
+ authn_type: str = json_file[AUTHN_TYPE_JSON_KEY]
243
+ return authn_type
244
+ except (FileNotFoundError, KeyError):
245
+ return AuthnType.NOOP
246
+
247
+
248
+ def load_cli_auth_plugin(
221
249
  root_dir: Path,
222
250
  federation: str,
223
251
  federation_config: dict[str, Any],
224
- auth_type: Optional[str] = None,
225
- ) -> Optional[CliAuthPlugin]:
226
- """Load the CLI-side user auth plugin for the given auth type."""
227
- # Check if user auth is enabled
228
- if not federation_config.get("enable-user-auth", False):
229
- return None
230
-
231
- config_path = get_user_auth_config_path(root_dir, federation)
232
-
233
- # Get the auth type from the config if not provided
234
- # auth_type will be None for all CLI commands except login
235
- if auth_type is None:
236
- try:
237
- with config_path.open("r", encoding="utf-8") as file:
238
- json_file = json.load(file)
239
- auth_type = json_file[AUTH_TYPE_JSON_KEY]
240
- except (FileNotFoundError, KeyError):
241
- typer.secho(
242
- "❌ Missing or invalid credentials for user authentication. "
243
- "Please run `flwr login` to authenticate.",
244
- fg=typer.colors.RED,
245
- bold=True,
246
- )
247
- raise typer.Exit(code=1) from None
252
+ authn_type: Optional[str] = None,
253
+ ) -> CliAuthPlugin:
254
+ """Load the CLI-side account auth plugin for the given authn type."""
255
+ # Find the path to the account auth config file
256
+ config_path = get_account_auth_config_path(root_dir, federation)
257
+
258
+ # Determine the auth type if not provided
259
+ # Only `flwr login` command can provide `authn_type` explicitly, as it can query the
260
+ # SuperLink for the auth type.
261
+ if authn_type is None:
262
+ authn_type = AuthnType.NOOP
263
+ if account_auth_enabled(federation_config):
264
+ authn_type = retrieve_authn_type(config_path)
248
265
 
249
266
  # Retrieve auth plugin class and instantiate it
250
267
  try:
251
- all_plugins: dict[str, type[CliAuthPlugin]] = get_cli_auth_plugins()
252
- auth_plugin_class = all_plugins[auth_type]
268
+ auth_plugin_class = get_cli_plugin_class(authn_type)
253
269
  return auth_plugin_class(config_path)
254
- except KeyError:
255
- typer.echo(f"❌ Unknown user authentication type: {auth_type}")
256
- raise typer.Exit(code=1) from None
257
- except ImportError:
258
- typer.echo("❌ No authentication plugins are currently supported.")
270
+ except ValueError:
271
+ typer.echo(f"❌ Unknown account authentication type: {authn_type}")
259
272
  raise typer.Exit(code=1) from None
260
273
 
261
274
 
262
275
  def init_channel(
263
- app: Path, federation_config: dict[str, Any], auth_plugin: Optional[CliAuthPlugin]
276
+ app: Path, federation_config: dict[str, Any], auth_plugin: CliAuthPlugin
264
277
  ) -> grpc.Channel:
265
278
  """Initialize gRPC channel to the Control API."""
266
279
  insecure, root_certificates_bytes = validate_certificate_in_federation_config(
267
280
  app, federation_config
268
281
  )
269
282
 
270
- # Initialize the CLI-side user auth interceptor
271
- interceptors: list[grpc.UnaryUnaryClientInterceptor] = []
272
- if auth_plugin is not None:
273
- # Check if TLS is enabled. If not, raise an error
274
- if insecure:
275
- typer.secho(
276
- "❌ User authentication requires TLS to be enabled. "
277
- "Remove `insecure = true` from the federation configuration.",
278
- fg=typer.colors.RED,
279
- bold=True,
280
- )
281
- raise typer.Exit(code=1)
282
-
283
- auth_plugin.load_tokens()
284
- interceptors.append(CliUserAuthInterceptor(auth_plugin))
283
+ # Load tokens
284
+ auth_plugin.load_tokens()
285
285
 
286
286
  # Create the gRPC channel
287
287
  channel = create_channel(
@@ -289,14 +289,14 @@ def init_channel(
289
289
  insecure=insecure,
290
290
  root_certificates=root_certificates_bytes,
291
291
  max_message_length=GRPC_MAX_MESSAGE_LENGTH,
292
- interceptors=interceptors or None,
292
+ interceptors=[CliAccountAuthInterceptor(auth_plugin)],
293
293
  )
294
294
  channel.subscribe(on_channel_state_change)
295
295
  return channel
296
296
 
297
297
 
298
298
  @contextmanager
299
- def flwr_cli_grpc_exc_handler() -> Iterator[None]:
299
+ def flwr_cli_grpc_exc_handler() -> Iterator[None]: # pylint: disable=too-many-branches
300
300
  """Context manager to handle specific gRPC errors.
301
301
 
302
302
  It catches grpc.RpcError exceptions with UNAUTHENTICATED, UNIMPLEMENTED,
@@ -315,9 +315,9 @@ def flwr_cli_grpc_exc_handler() -> Iterator[None]:
315
315
  )
316
316
  raise typer.Exit(code=1) from None
317
317
  if e.code() == grpc.StatusCode.UNIMPLEMENTED:
318
- if e.details() == NO_USER_AUTH_MESSAGE: # pylint: disable=E1101
318
+ if e.details() == NO_ACCOUNT_AUTH_MESSAGE: # pylint: disable=E1101
319
319
  typer.secho(
320
- "❌ User authentication is not enabled on this SuperLink.",
320
+ "❌ Account authentication is not enabled on this SuperLink.",
321
321
  fg=typer.colors.RED,
322
322
  bold=True,
323
323
  )
@@ -354,16 +354,21 @@ def flwr_cli_grpc_exc_handler() -> Iterator[None]:
354
354
  bold=True,
355
355
  )
356
356
  raise typer.Exit(code=1) from None
357
- if (
358
- e.code() == grpc.StatusCode.NOT_FOUND
359
- and e.details() == RUN_ID_NOT_FOUND_MESSAGE # pylint: disable=E1101
360
- ):
361
- typer.secho(
362
- "❌ Run ID not found.",
363
- fg=typer.colors.RED,
364
- bold=True,
365
- )
366
- raise typer.Exit(code=1) from None
357
+ if e.code() == grpc.StatusCode.NOT_FOUND:
358
+ if e.details() == RUN_ID_NOT_FOUND_MESSAGE: # pylint: disable=E1101
359
+ typer.secho(
360
+ "❌ Run ID not found.",
361
+ fg=typer.colors.RED,
362
+ bold=True,
363
+ )
364
+ raise typer.Exit(code=1) from None
365
+ if e.details() == NODE_NOT_FOUND_MESSAGE: # pylint: disable=E1101
366
+ typer.secho(
367
+ "❌ Node ID not found for this account.",
368
+ fg=typer.colors.RED,
369
+ bold=True,
370
+ )
371
+ raise typer.Exit(code=1) from None
367
372
  if e.code() == grpc.StatusCode.FAILED_PRECONDITION:
368
373
  if e.details() == PULL_UNFINISHED_RUN_MESSAGE: # pylint: disable=E1101
369
374
  typer.secho(
@@ -373,4 +378,22 @@ def flwr_cli_grpc_exc_handler() -> Iterator[None]:
373
378
  bold=True,
374
379
  )
375
380
  raise typer.Exit(code=1) from None
381
+ if (
382
+ e.details() == PUBLIC_KEY_ALREADY_IN_USE_MESSAGE
383
+ ): # pylint: disable=E1101
384
+ typer.secho(
385
+ "❌ The provided public key is already in use by another "
386
+ "SuperNode.",
387
+ fg=typer.colors.RED,
388
+ bold=True,
389
+ )
390
+ raise typer.Exit(code=1) from None
391
+ if e.details() == PUBLIC_KEY_NOT_VALID: # pylint: disable=E1101
392
+ typer.secho(
393
+ "❌ The provided public key is invalid. Please provide a valid "
394
+ "NIST EC public key.",
395
+ fg=typer.colors.RED,
396
+ bold=True,
397
+ )
398
+ raise typer.Exit(code=1) from None
376
399
  raise
flwr/client/__init__.py CHANGED
@@ -15,10 +15,11 @@
15
15
  """Flower client."""
16
16
 
17
17
 
18
+ from flwr.clientapp import ClientApp
19
+
18
20
  from ..compat.client.app import start_client as start_client # Deprecated
19
21
  from ..compat.client.app import start_numpy_client as start_numpy_client # Deprecated
20
22
  from .client import Client as Client
21
- from .client_app import ClientApp as ClientApp
22
23
  from .numpy_client import NumPyClient as NumPyClient
23
24
  from .typing import ClientFn as ClientFn
24
25
  from .typing import ClientFnExt as ClientFnExt
@@ -44,10 +44,9 @@ def grpc_adapter( # pylint: disable=R0913,too-many-positional-arguments
44
44
  ] = None,
45
45
  ) -> Iterator[
46
46
  tuple[
47
+ int,
47
48
  Callable[[], Optional[tuple[Message, ObjectTree]]],
48
49
  Callable[[Message, ObjectTree], set[str]],
49
- Callable[[], Optional[int]],
50
- Callable[[], None],
51
50
  Callable[[int], Run],
52
51
  Callable[[str, int], Fab],
53
52
  Callable[[int, str], bytes],
@@ -77,22 +76,21 @@ def grpc_adapter( # pylint: disable=R0913,too-many-positional-arguments
77
76
  connection using the certificates will be established to an SSL-enabled
78
77
  Flower server. Bytes won't work for the REST API.
79
78
  authentication_keys : Optional[Tuple[PrivateKey, PublicKey]] (default: None)
80
- Client authentication is not supported for this transport type.
79
+ SuperNode authentication is not supported for this transport type.
81
80
 
82
81
  Returns
83
82
  -------
83
+ node_id : int
84
84
  receive : Callable[[], Optional[tuple[Message, ObjectTree]]]
85
85
  send : Callable[[Message, ObjectTree], set[str]]
86
- create_node : Optional[Callable]
87
- delete_node : Optional[Callable]
88
- get_run : Optional[Callable]
89
- get_fab : Optional[Callable]
86
+ get_run : Callable[[int], Run]
87
+ get_fab : Callable[[str, int], Fab]
90
88
  pull_object : Callable[[str], bytes]
91
89
  push_object : Callable[[str, bytes], None]
92
90
  confirm_message_received : Callable[[str], None]
93
91
  """
94
92
  if authentication_keys is not None:
95
- log(ERROR, "Client authentication is not supported for this transport type.")
93
+ log(ERROR, "SuperNode authentication is not supported for this transport type.")
96
94
  with grpc_request_response(
97
95
  server_address=server_address,
98
96
  insecure=insecure,
@@ -19,7 +19,7 @@ from collections.abc import Iterator, Sequence
19
19
  from contextlib import contextmanager
20
20
  from logging import ERROR
21
21
  from pathlib import Path
22
- from typing import Callable, Optional, Union, cast
22
+ from typing import Callable, Optional, Union
23
23
 
24
24
  import grpc
25
25
  from cryptography.hazmat.primitives.asymmetric import ec
@@ -36,19 +36,24 @@ from flwr.common.inflatable_protobuf_utils import (
36
36
  from flwr.common.logger import log
37
37
  from flwr.common.message import Message, remove_content_from_message
38
38
  from flwr.common.retry_invoker import RetryInvoker, _wrap_stub
39
- from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
40
- generate_key_pairs,
39
+ from flwr.common.serde import (
40
+ fab_from_proto,
41
+ message_from_proto,
42
+ message_to_proto,
43
+ run_from_proto,
41
44
  )
42
- from flwr.common.serde import message_from_proto, message_to_proto, run_from_proto
43
45
  from flwr.common.typing import Fab, Run
44
46
  from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
45
47
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
46
- CreateNodeRequest,
47
- DeleteNodeRequest,
48
+ ActivateNodeRequest,
49
+ ActivateNodeResponse,
50
+ DeactivateNodeRequest,
48
51
  PullMessagesRequest,
49
52
  PullMessagesResponse,
50
53
  PushMessagesRequest,
51
54
  PushMessagesResponse,
55
+ RegisterNodeFleetRequest,
56
+ UnregisterNodeFleetRequest,
52
57
  )
53
58
  from flwr.proto.fleet_pb2_grpc import FleetStub # pylint: disable=E0611
54
59
  from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
@@ -58,9 +63,10 @@ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
58
63
  from flwr.proto.message_pb2 import ObjectTree # pylint: disable=E0611
59
64
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
60
65
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
66
+ from flwr.supercore.primitives.asymmetric import generate_key_pairs, public_key_to_bytes
61
67
 
62
- from .client_interceptor import AuthenticateClientInterceptor
63
68
  from .grpc_adapter import GrpcAdapter
69
+ from .node_auth_client_interceptor import NodeAuthClientInterceptor
64
70
 
65
71
 
66
72
  @contextmanager
@@ -76,10 +82,9 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
76
82
  adapter_cls: Optional[Union[type[FleetStub], type[GrpcAdapter]]] = None,
77
83
  ) -> Iterator[
78
84
  tuple[
85
+ int,
79
86
  Callable[[], Optional[tuple[Message, ObjectTree]]],
80
87
  Callable[[Message, ObjectTree], set[str]],
81
- Callable[[], Optional[int]],
82
- Callable[[], None],
83
88
  Callable[[int], Run],
84
89
  Callable[[str, int], Fab],
85
90
  Callable[[int, str], bytes],
@@ -122,11 +127,11 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
122
127
 
123
128
  Returns
124
129
  -------
125
- receive : Callable
126
- send : Callable
127
- create_node : Optional[Callable]
128
- delete_node : Optional[Callable]
129
- get_run : Optional[Callable]
130
+ node_id : int
131
+ receive : Callable[[], Optional[tuple[Message, ObjectTree]]]
132
+ send : Callable[[Message, ObjectTree], set[str]]
133
+ get_run : Callable[[int], Run]
134
+ get_fab : Callable[[str, int], Fab]
130
135
  pull_object : Callable[[str], bytes]
131
136
  push_object : Callable[[str, bytes], None]
132
137
  confirm_message_received : Callable[[str], None]
@@ -135,13 +140,16 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
135
140
  root_certificates = Path(root_certificates).read_bytes()
136
141
 
137
142
  # Automatic node auth: generate keys if user didn't provide any
143
+ self_registered = False
138
144
  if authentication_keys is None:
145
+ self_registered = True
139
146
  authentication_keys = generate_key_pairs()
140
147
 
141
148
  # Always configure auth interceptor, with either user-provided or generated keys
142
149
  interceptors: Sequence[grpc.UnaryUnaryClientInterceptor] = [
143
- AuthenticateClientInterceptor(*authentication_keys),
150
+ NodeAuthClientInterceptor(*authentication_keys),
144
151
  ]
152
+ node_pk = public_key_to_bytes(authentication_keys[1])
145
153
  channel = create_channel(
146
154
  server_address=server_address,
147
155
  insecure=insecure,
@@ -160,7 +168,7 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
160
168
  # Wrap stub
161
169
  _wrap_stub(stub, retry_invoker)
162
170
  ###########################################################################
163
- # send_node_heartbeat/create_node/delete_node/receive/send/get_run functions
171
+ # SuperNode functions
164
172
  ###########################################################################
165
173
 
166
174
  def send_node_heartbeat() -> bool:
@@ -197,22 +205,26 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
197
205
 
198
206
  heartbeat_sender = HeartbeatSender(send_node_heartbeat)
199
207
 
200
- def create_node() -> Optional[int]:
201
- """Set create_node."""
202
- # Call FleetAPI
203
- create_node_request = CreateNodeRequest(
204
- heartbeat_interval=HEARTBEAT_DEFAULT_INTERVAL
208
+ def register_node() -> None:
209
+ """Register node with SuperLink."""
210
+ stub.RegisterNode(RegisterNodeFleetRequest(public_key=node_pk))
211
+
212
+ def activate_node() -> int:
213
+ """Activate node and start heartbeat."""
214
+ req = ActivateNodeRequest(
215
+ public_key=node_pk,
216
+ heartbeat_interval=HEARTBEAT_DEFAULT_INTERVAL,
205
217
  )
206
- create_node_response = stub.CreateNode(request=create_node_request)
218
+ res: ActivateNodeResponse = stub.ActivateNode(req)
207
219
 
208
220
  # Remember the node and start the heartbeat sender
209
221
  nonlocal node
210
- node = cast(Node, create_node_response.node)
222
+ node = Node(node_id=res.node_id)
211
223
  heartbeat_sender.start()
212
224
  return node.node_id
213
225
 
214
- def delete_node() -> None:
215
- """Set delete_node."""
226
+ def deactivate_node() -> None:
227
+ """Deactivate node and stop heartbeat."""
216
228
  # Get Node
217
229
  nonlocal node
218
230
  if node is None:
@@ -223,8 +235,20 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
223
235
  heartbeat_sender.stop()
224
236
 
225
237
  # Call FleetAPI
226
- delete_node_request = DeleteNodeRequest(node=node)
227
- stub.DeleteNode(request=delete_node_request)
238
+ req = DeactivateNodeRequest(node_id=node.node_id)
239
+ stub.DeactivateNode(req)
240
+
241
+ def unregister_node() -> None:
242
+ """Unregister node from SuperLink."""
243
+ # Get Node
244
+ nonlocal node
245
+ if node is None:
246
+ log(ERROR, "Node instance missing")
247
+ return
248
+
249
+ # Call FleetAPI
250
+ req = UnregisterNodeFleetRequest(node_id=node.node_id)
251
+ stub.UnregisterNode(req)
228
252
 
229
253
  # Cleanup
230
254
  node = None
@@ -289,7 +313,7 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
289
313
  get_fab_request = GetFabRequest(node=node, hash_str=fab_hash, run_id=run_id)
290
314
  get_fab_response: GetFabResponse = stub.GetFab(request=get_fab_request)
291
315
 
292
- return Fab(get_fab_response.fab.hash_str, get_fab_response.fab.content)
316
+ return fab_from_proto(get_fab_response.fab)
293
317
 
294
318
  def pull_object(run_id: int, object_id: str) -> bytes:
295
319
  """Pull the object from the SuperLink."""
@@ -331,12 +355,14 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
331
355
  fn(object_id)
332
356
 
333
357
  try:
358
+ if self_registered:
359
+ register_node()
360
+ node_id = activate_node()
334
361
  # Yield methods
335
362
  yield (
363
+ node_id,
336
364
  receive,
337
365
  send,
338
- create_node,
339
- delete_node,
340
366
  get_run,
341
367
  get_fab,
342
368
  pull_object,
@@ -351,7 +377,9 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
351
377
  if node is not None:
352
378
  # Disable retrying
353
379
  retry_invoker.max_tries = 1
354
- delete_node()
380
+ deactivate_node()
381
+ if self_registered:
382
+ unregister_node()
355
383
  except grpc.RpcError:
356
384
  pass
357
385
  channel.close()
@@ -34,14 +34,18 @@ from flwr.common.constant import (
34
34
  from flwr.common.version import package_name, package_version
35
35
  from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
36
36
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
37
- CreateNodeRequest,
38
- CreateNodeResponse,
39
- DeleteNodeRequest,
40
- DeleteNodeResponse,
37
+ ActivateNodeRequest,
38
+ ActivateNodeResponse,
39
+ DeactivateNodeRequest,
40
+ DeactivateNodeResponse,
41
41
  PullMessagesRequest,
42
42
  PullMessagesResponse,
43
43
  PushMessagesRequest,
44
44
  PushMessagesResponse,
45
+ RegisterNodeFleetRequest,
46
+ RegisterNodeFleetResponse,
47
+ UnregisterNodeFleetRequest,
48
+ UnregisterNodeFleetResponse,
45
49
  )
46
50
  from flwr.proto.grpcadapter_pb2 import MessageContainer # pylint: disable=E0611
47
51
  from flwr.proto.grpcadapter_pb2_grpc import GrpcAdapterStub
@@ -118,17 +122,29 @@ class GrpcAdapter:
118
122
  response.ParseFromString(container_res.grpc_message_content)
119
123
  return response
120
124
 
121
- def CreateNode( # pylint: disable=C0103
122
- self, request: CreateNodeRequest, **kwargs: Any
123
- ) -> CreateNodeResponse:
125
+ def RegisterNode( # pylint: disable=C0103
126
+ self, request: RegisterNodeFleetRequest, **kwargs: Any
127
+ ) -> RegisterNodeFleetResponse:
124
128
  """."""
125
- return self._send_and_receive(request, CreateNodeResponse, **kwargs)
129
+ return self._send_and_receive(request, RegisterNodeFleetResponse, **kwargs)
126
130
 
127
- def DeleteNode( # pylint: disable=C0103
128
- self, request: DeleteNodeRequest, **kwargs: Any
129
- ) -> DeleteNodeResponse:
131
+ def ActivateNode( # pylint: disable=C0103
132
+ self, request: ActivateNodeRequest, **kwargs: Any
133
+ ) -> ActivateNodeResponse:
130
134
  """."""
131
- return self._send_and_receive(request, DeleteNodeResponse, **kwargs)
135
+ return self._send_and_receive(request, ActivateNodeResponse, **kwargs)
136
+
137
+ def DeactivateNode( # pylint: disable=C0103
138
+ self, request: DeactivateNodeRequest, **kwargs: Any
139
+ ) -> DeactivateNodeResponse:
140
+ """."""
141
+ return self._send_and_receive(request, DeactivateNodeResponse, **kwargs)
142
+
143
+ def UnregisterNode( # pylint: disable=C0103
144
+ self, request: UnregisterNodeFleetRequest, **kwargs: Any
145
+ ) -> UnregisterNodeFleetResponse:
146
+ """."""
147
+ return self._send_and_receive(request, UnregisterNodeFleetResponse, **kwargs)
132
148
 
133
149
  def SendNodeHeartbeat( # pylint: disable=C0103
134
150
  self, request: SendNodeHeartbeatRequest, **kwargs: Any
@@ -23,14 +23,11 @@ from google.protobuf.message import Message as GrpcMessage
23
23
 
24
24
  from flwr.common import now
25
25
  from flwr.common.constant import PUBLIC_KEY_HEADER, SIGNATURE_HEADER, TIMESTAMP_HEADER
26
- from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
27
- public_key_to_bytes,
28
- sign_message,
29
- )
26
+ from flwr.supercore.primitives.asymmetric import public_key_to_bytes, sign_message
30
27
 
31
28
 
32
- class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type: ignore
33
- """Client interceptor for client authentication."""
29
+ class NodeAuthClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type: ignore
30
+ """Client interceptor for node authentication."""
34
31
 
35
32
  def __init__(
36
33
  self,
@@ -35,14 +35,9 @@ from flwr.common.constant import MessageType
35
35
  from flwr.common.logger import log
36
36
  from flwr.common.secure_aggregation.crypto.shamir import create_shares
37
37
  from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
38
- bytes_to_private_key,
39
- bytes_to_public_key,
40
38
  decrypt,
41
39
  encrypt,
42
- generate_key_pairs,
43
40
  generate_shared_key,
44
- private_key_to_bytes,
45
- public_key_to_bytes,
46
41
  )
47
42
  from flwr.common.secure_aggregation.ndarrays_arithmetic import (
48
43
  factor_combine,
@@ -64,6 +59,13 @@ from flwr.common.secure_aggregation.secaggplus_utils import (
64
59
  share_keys_plaintext_separate,
65
60
  )
66
61
  from flwr.common.typing import ConfigRecordValues
62
+ from flwr.supercore.primitives.asymmetric import (
63
+ bytes_to_private_key,
64
+ bytes_to_public_key,
65
+ generate_key_pairs,
66
+ private_key_to_bytes,
67
+ public_key_to_bytes,
68
+ )
67
69
 
68
70
 
69
71
  @dataclass