flwr-nightly 1.23.0.dev20251007__py3-none-any.whl → 1.23.0.dev20251008__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. flwr/cli/auth_plugin/__init__.py +7 -3
  2. flwr/cli/log.py +2 -2
  3. flwr/cli/login/login.py +4 -13
  4. flwr/cli/ls.py +2 -2
  5. flwr/cli/pull.py +2 -2
  6. flwr/cli/run/run.py +2 -2
  7. flwr/cli/stop.py +2 -2
  8. flwr/cli/supernode/ls.py +2 -2
  9. flwr/cli/utils.py +28 -44
  10. flwr/client/grpc_rere_client/connection.py +6 -4
  11. flwr/client/grpc_rere_client/{client_interceptor.py → node_auth_client_interceptor.py} +2 -2
  12. flwr/client/rest_client/connection.py +7 -1
  13. flwr/common/constant.py +10 -0
  14. flwr/proto/fleet_pb2.py +22 -22
  15. flwr/proto/fleet_pb2.pyi +4 -1
  16. flwr/proto/node_pb2.py +1 -1
  17. flwr/server/app.py +32 -31
  18. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +8 -4
  19. flwr/server/superlink/fleet/grpc_rere/{server_interceptor.py → node_auth_server_interceptor.py} +18 -37
  20. flwr/server/superlink/fleet/message_handler/message_handler.py +1 -1
  21. flwr/server/superlink/fleet/vce/vce_api.py +7 -1
  22. flwr/server/superlink/linkstate/in_memory_linkstate.py +39 -27
  23. flwr/server/superlink/linkstate/linkstate.py +1 -1
  24. flwr/server/superlink/linkstate/sqlite_linkstate.py +37 -21
  25. flwr/server/utils/validator.py +2 -3
  26. flwr/superlink/auth_plugin/__init__.py +29 -0
  27. flwr/superlink/servicer/control/control_grpc.py +9 -7
  28. flwr/superlink/servicer/control/control_servicer.py +34 -46
  29. {flwr_nightly-1.23.0.dev20251007.dist-info → flwr_nightly-1.23.0.dev20251008.dist-info}/METADATA +1 -1
  30. {flwr_nightly-1.23.0.dev20251007.dist-info → flwr_nightly-1.23.0.dev20251008.dist-info}/RECORD +32 -32
  31. {flwr_nightly-1.23.0.dev20251007.dist-info → flwr_nightly-1.23.0.dev20251008.dist-info}/WHEEL +0 -0
  32. {flwr_nightly-1.23.0.dev20251007.dist-info → flwr_nightly-1.23.0.dev20251008.dist-info}/entry_points.txt +0 -0
flwr/cli/auth_plugin/__init__.py CHANGED
@@ -22,9 +22,13 @@ from .noop_auth_plugin import NoOpCliAuthPlugin
22
22
  from .oidc_cli_plugin import OidcCliPlugin
23
23
 
24
24
 
25
- def get_cli_auth_plugins() -> dict[str, type[CliAuthPlugin]]:
25
+ def get_cli_plugin_class(authn_type: str) -> type[CliAuthPlugin]:
26
26
  """Return all CLI authentication plugins."""
27
- return {AuthnType.NOOP: NoOpCliAuthPlugin, AuthnType.OIDC: OidcCliPlugin}
27
+ if authn_type == AuthnType.NOOP:
28
+ return NoOpCliAuthPlugin
29
+ if authn_type == AuthnType.OIDC:
30
+ return OidcCliPlugin
31
+ raise ValueError(f"Unsupported authentication type: {authn_type}")
28
32
 
29
33
 
30
34
  __all__ = [
@@ -32,5 +36,5 @@ __all__ = [
32
36
  "LoginError",
33
37
  "NoOpCliAuthPlugin",
34
38
  "OidcCliPlugin",
35
- "get_cli_auth_plugins",
39
+ "get_cli_plugin_class",
36
40
  ]
flwr/cli/log.py CHANGED
@@ -35,7 +35,7 @@ from flwr.common.logger import log as logger
35
35
  from flwr.proto.control_pb2 import StreamLogsRequest # pylint: disable=E0611
36
36
  from flwr.proto.control_pb2_grpc import ControlStub
37
37
 
38
- from .utils import flwr_cli_grpc_exc_handler, init_channel, try_obtain_cli_auth_plugin
38
+ from .utils import flwr_cli_grpc_exc_handler, init_channel, load_cli_auth_plugin
39
39
 
40
40
 
41
41
  class AllLogsRetrieved(BaseException):
@@ -186,7 +186,7 @@ def _log_with_control_api(
186
186
  run_id: int,
187
187
  stream: bool,
188
188
  ) -> None:
189
- auth_plugin = try_obtain_cli_auth_plugin(app, federation, federation_config)
189
+ auth_plugin = load_cli_auth_plugin(app, federation, federation_config)
190
190
  channel = init_channel(app, federation_config, auth_plugin)
191
191
 
192
192
  if stream:
flwr/cli/login/login.py CHANGED
@@ -20,7 +20,7 @@ from typing import Annotated, Optional
20
20
 
21
21
  import typer
22
22
 
23
- from flwr.cli.auth_plugin import LoginError
23
+ from flwr.cli.auth_plugin import LoginError, NoOpCliAuthPlugin
24
24
  from flwr.cli.config_utils import (
25
25
  exit_if_no_address,
26
26
  get_insecure_flag,
@@ -40,7 +40,7 @@ from ..utils import (
40
40
  account_auth_enabled,
41
41
  flwr_cli_grpc_exc_handler,
42
42
  init_channel,
43
- try_obtain_cli_auth_plugin,
43
+ load_cli_auth_plugin,
44
44
  )
45
45
 
46
46
 
@@ -95,7 +95,7 @@ def login( # pylint: disable=R0914
95
95
  )
96
96
  raise typer.Exit(code=1)
97
97
 
98
- channel = init_channel(app, federation_config, None)
98
+ channel = init_channel(app, federation_config, NoOpCliAuthPlugin(Path()))
99
99
  stub = ControlStub(channel)
100
100
 
101
101
  login_request = GetLoginDetailsRequest()
@@ -104,16 +104,7 @@ def login( # pylint: disable=R0914
104
104
 
105
105
  # Get the auth plugin
106
106
  authn_type = login_response.authn_type
107
- auth_plugin = try_obtain_cli_auth_plugin(
108
- app, federation, federation_config, authn_type
109
- )
110
- if auth_plugin is None:
111
- typer.secho(
112
- f'❌ Authentication type "{authn_type}" not found',
113
- fg=typer.colors.RED,
114
- bold=True,
115
- )
116
- raise typer.Exit(code=1)
107
+ auth_plugin = load_cli_auth_plugin(app, federation, federation_config, authn_type)
117
108
 
118
109
  # Login
119
110
  details = AccountAuthLoginDetails(
flwr/cli/ls.py CHANGED
@@ -44,7 +44,7 @@ from flwr.proto.control_pb2 import ( # pylint: disable=E0611
44
44
  )
45
45
  from flwr.proto.control_pb2_grpc import ControlStub
46
46
 
47
- from .utils import flwr_cli_grpc_exc_handler, init_channel, try_obtain_cli_auth_plugin
47
+ from .utils import flwr_cli_grpc_exc_handler, init_channel, load_cli_auth_plugin
48
48
 
49
49
  _RunListType = tuple[int, str, str, str, str, str, str, str, str]
50
50
 
@@ -127,7 +127,7 @@ def ls( # pylint: disable=too-many-locals, too-many-branches, R0913, R0917
127
127
  raise ValueError(
128
128
  "The options '--runs' and '--run-id' are mutually exclusive."
129
129
  )
130
- auth_plugin = try_obtain_cli_auth_plugin(app, federation, federation_config)
130
+ auth_plugin = load_cli_auth_plugin(app, federation, federation_config)
131
131
  channel = init_channel(app, federation_config, auth_plugin)
132
132
  stub = ControlStub(channel)
133
133
 
flwr/cli/pull.py CHANGED
@@ -34,7 +34,7 @@ from flwr.proto.control_pb2 import ( # pylint: disable=E0611
34
34
  )
35
35
  from flwr.proto.control_pb2_grpc import ControlStub
36
36
 
37
- from .utils import flwr_cli_grpc_exc_handler, init_channel, try_obtain_cli_auth_plugin
37
+ from .utils import flwr_cli_grpc_exc_handler, init_channel, load_cli_auth_plugin
38
38
 
39
39
 
40
40
  def pull( # pylint: disable=R0914
@@ -74,7 +74,7 @@ def pull( # pylint: disable=R0914
74
74
  channel = None
75
75
  try:
76
76
 
77
- auth_plugin = try_obtain_cli_auth_plugin(app, federation, federation_config)
77
+ auth_plugin = load_cli_auth_plugin(app, federation, federation_config)
78
78
  channel = init_channel(app, federation_config, auth_plugin)
79
79
  stub = ControlStub(channel)
80
80
  with flwr_cli_grpc_exc_handler():
flwr/cli/run/run.py CHANGED
@@ -45,7 +45,7 @@ from flwr.proto.control_pb2 import StartRunRequest # pylint: disable=E0611
45
45
  from flwr.proto.control_pb2_grpc import ControlStub
46
46
 
47
47
  from ..log import start_stream
48
- from ..utils import flwr_cli_grpc_exc_handler, init_channel, try_obtain_cli_auth_plugin
48
+ from ..utils import flwr_cli_grpc_exc_handler, init_channel, load_cli_auth_plugin
49
49
 
50
50
  CONN_REFRESH_PERIOD = 60 # Connection refresh period for log streaming (seconds)
51
51
 
@@ -148,7 +148,7 @@ def _run_with_control_api(
148
148
  ) -> None:
149
149
  channel = None
150
150
  try:
151
- auth_plugin = try_obtain_cli_auth_plugin(app, federation, federation_config)
151
+ auth_plugin = load_cli_auth_plugin(app, federation, federation_config)
152
152
  channel = init_channel(app, federation_config, auth_plugin)
153
153
  stub = ControlStub(channel)
154
154
 
flwr/cli/stop.py CHANGED
@@ -38,7 +38,7 @@ from flwr.proto.control_pb2 import ( # pylint: disable=E0611
38
38
  )
39
39
  from flwr.proto.control_pb2_grpc import ControlStub
40
40
 
41
- from .utils import flwr_cli_grpc_exc_handler, init_channel, try_obtain_cli_auth_plugin
41
+ from .utils import flwr_cli_grpc_exc_handler, init_channel, load_cli_auth_plugin
42
42
 
43
43
 
44
44
  def stop( # pylint: disable=R0914
@@ -89,7 +89,7 @@ def stop( # pylint: disable=R0914
89
89
  exit_if_no_address(federation_config, "stop")
90
90
  channel = None
91
91
  try:
92
- auth_plugin = try_obtain_cli_auth_plugin(app, federation, federation_config)
92
+ auth_plugin = load_cli_auth_plugin(app, federation, federation_config)
93
93
  channel = init_channel(app, federation_config, auth_plugin)
94
94
  stub = ControlStub(channel) # pylint: disable=unused-variable # noqa: F841
95
95
 
flwr/cli/supernode/ls.py CHANGED
@@ -42,7 +42,7 @@ from flwr.proto.control_pb2 import ( # pylint: disable=E0611
42
42
  from flwr.proto.control_pb2_grpc import ControlStub
43
43
  from flwr.proto.node_pb2 import NodeInfo # pylint: disable=E0611
44
44
 
45
- from ..utils import flwr_cli_grpc_exc_handler, init_channel, try_obtain_cli_auth_plugin
45
+ from ..utils import flwr_cli_grpc_exc_handler, init_channel, load_cli_auth_plugin
46
46
 
47
47
  _NodeListType = tuple[int, str, str, str, str, str, str, str]
48
48
 
@@ -94,7 +94,7 @@ def ls( # pylint: disable=R0914
94
94
  exit_if_no_address(federation_config, f"supernode {command_name}")
95
95
  channel = None
96
96
  try:
97
- auth_plugin = try_obtain_cli_auth_plugin(app, federation, federation_config)
97
+ auth_plugin = load_cli_auth_plugin(app, federation, federation_config)
98
98
  channel = init_channel(app, federation_config, auth_plugin)
99
99
  stub = ControlStub(channel)
100
100
  typer.echo("📄 Listing all nodes...")
flwr/cli/utils.py CHANGED
@@ -34,6 +34,7 @@ from flwr.common.constant import (
34
34
  NO_ARTIFACT_PROVIDER_MESSAGE,
35
35
  PULL_UNFINISHED_RUN_MESSAGE,
36
36
  RUN_ID_NOT_FOUND_MESSAGE,
37
+ AuthnType,
37
38
  )
38
39
  from flwr.common.grpc import (
39
40
  GRPC_MAX_MESSAGE_LENGTH,
@@ -41,7 +42,7 @@ from flwr.common.grpc import (
41
42
  on_channel_state_change,
42
43
  )
43
44
 
44
- from .auth_plugin import CliAuthPlugin, get_cli_auth_plugins
45
+ from .auth_plugin import CliAuthPlugin, get_cli_plugin_class
45
46
  from .cli_account_auth_interceptor import CliAccountAuthInterceptor
46
47
  from .config_utils import validate_certificate_in_federation_config
47
48
 
@@ -230,71 +231,54 @@ def account_auth_enabled(federation_config: dict[str, Any]) -> bool:
230
231
  return enabled
231
232
 
232
233
 
233
- def try_obtain_cli_auth_plugin(
234
+ def retrieve_authn_type(config_path: Path) -> str:
235
+ """Retrieve the auth type from the config file or return NOOP if not found."""
236
+ try:
237
+ with config_path.open("r", encoding="utf-8") as file:
238
+ json_file = json.load(file)
239
+ authn_type: str = json_file[AUTHN_TYPE_JSON_KEY]
240
+ return authn_type
241
+ except (FileNotFoundError, KeyError):
242
+ return AuthnType.NOOP
243
+
244
+
245
+ def load_cli_auth_plugin(
234
246
  root_dir: Path,
235
247
  federation: str,
236
248
  federation_config: dict[str, Any],
237
249
  authn_type: Optional[str] = None,
238
- ) -> Optional[CliAuthPlugin]:
250
+ ) -> CliAuthPlugin:
239
251
  """Load the CLI-side account auth plugin for the given authn type."""
240
- # Check if account auth is enabled
241
- if not account_auth_enabled(federation_config):
242
- return None
243
-
252
+ # Find the path to the account auth config file
244
253
  config_path = get_account_auth_config_path(root_dir, federation)
245
254
 
246
- # Get the authn type from the config if not provided
247
- # authn_type will be None for all CLI commands except login
255
+ # Determine the auth type if not provided
256
+ # Only `flwr login` command can provide `authn_type` explicitly, as it can query the
257
+ # SuperLink for the auth type.
248
258
  if authn_type is None:
249
- try:
250
- with config_path.open("r", encoding="utf-8") as file:
251
- json_file = json.load(file)
252
- authn_type = json_file[AUTHN_TYPE_JSON_KEY]
253
- except (FileNotFoundError, KeyError):
254
- typer.secho(
255
- "❌ Missing or invalid credentials for account authentication. "
256
- "Please run `flwr login` to authenticate.",
257
- fg=typer.colors.RED,
258
- bold=True,
259
- )
260
- raise typer.Exit(code=1) from None
259
+ authn_type = AuthnType.NOOP
260
+ if account_auth_enabled(federation_config):
261
+ authn_type = retrieve_authn_type(config_path)
261
262
 
262
263
  # Retrieve auth plugin class and instantiate it
263
264
  try:
264
- all_plugins: dict[str, type[CliAuthPlugin]] = get_cli_auth_plugins()
265
- auth_plugin_class = all_plugins[authn_type]
265
+ auth_plugin_class = get_cli_plugin_class(authn_type)
266
266
  return auth_plugin_class(config_path)
267
- except KeyError:
267
+ except ValueError:
268
268
  typer.echo(f"❌ Unknown account authentication type: {authn_type}")
269
269
  raise typer.Exit(code=1) from None
270
- except ImportError:
271
- typer.echo("❌ No authentication plugins are currently supported.")
272
- raise typer.Exit(code=1) from None
273
270
 
274
271
 
275
272
  def init_channel(
276
- app: Path, federation_config: dict[str, Any], auth_plugin: Optional[CliAuthPlugin]
273
+ app: Path, federation_config: dict[str, Any], auth_plugin: CliAuthPlugin
277
274
  ) -> grpc.Channel:
278
275
  """Initialize gRPC channel to the Control API."""
279
276
  insecure, root_certificates_bytes = validate_certificate_in_federation_config(
280
277
  app, federation_config
281
278
  )
282
279
 
283
- # Initialize the CLI-side account auth interceptor
284
- interceptors: list[grpc.UnaryUnaryClientInterceptor] = []
285
- if auth_plugin is not None:
286
- # Check if TLS is enabled. If not, raise an error
287
- if insecure:
288
- typer.secho(
289
- "❌ Account authentication requires TLS to be enabled. "
290
- "Remove `insecure = true` from the federation configuration.",
291
- fg=typer.colors.RED,
292
- bold=True,
293
- )
294
- raise typer.Exit(code=1)
295
-
296
- auth_plugin.load_tokens()
297
- interceptors.append(CliAccountAuthInterceptor(auth_plugin))
280
+ # Load tokens
281
+ auth_plugin.load_tokens()
298
282
 
299
283
  # Create the gRPC channel
300
284
  channel = create_channel(
@@ -302,7 +286,7 @@ def init_channel(
302
286
  insecure=insecure,
303
287
  root_certificates=root_certificates_bytes,
304
288
  max_message_length=GRPC_MAX_MESSAGE_LENGTH,
305
- interceptors=interceptors or None,
289
+ interceptors=[CliAccountAuthInterceptor(auth_plugin)],
306
290
  )
307
291
  channel.subscribe(on_channel_state_change)
308
292
  return channel
flwr/client/grpc_rere_client/connection.py CHANGED
@@ -55,10 +55,10 @@ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
55
55
  from flwr.proto.message_pb2 import ObjectTree # pylint: disable=E0611
56
56
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
57
57
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
58
- from flwr.supercore.primitives.asymmetric import generate_key_pairs
58
+ from flwr.supercore.primitives.asymmetric import generate_key_pairs, public_key_to_bytes
59
59
 
60
- from .client_interceptor import AuthenticateClientInterceptor
61
60
  from .grpc_adapter import GrpcAdapter
61
+ from .node_auth_client_interceptor import NodeAuthClientInterceptor
62
62
 
63
63
 
64
64
  @contextmanager
@@ -138,8 +138,9 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
138
138
 
139
139
  # Always configure auth interceptor, with either user-provided or generated keys
140
140
  interceptors: Sequence[grpc.UnaryUnaryClientInterceptor] = [
141
- AuthenticateClientInterceptor(*authentication_keys),
141
+ NodeAuthClientInterceptor(*authentication_keys),
142
142
  ]
143
+ node_pk = public_key_to_bytes(authentication_keys[1])
143
144
  channel = create_channel(
144
145
  server_address=server_address,
145
146
  insecure=insecure,
@@ -199,7 +200,8 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
199
200
  """Set create_node."""
200
201
  # Call FleetAPI
201
202
  create_node_request = CreateNodeRequest(
202
- heartbeat_interval=HEARTBEAT_DEFAULT_INTERVAL
203
+ public_key=node_pk,
204
+ heartbeat_interval=HEARTBEAT_DEFAULT_INTERVAL,
203
205
  )
204
206
  create_node_response = stub.CreateNode(request=create_node_request)
205
207
 
flwr/client/grpc_rere_client/{client_interceptor.py → node_auth_client_interceptor.py} CHANGED
@@ -26,8 +26,8 @@ from flwr.common.constant import PUBLIC_KEY_HEADER, SIGNATURE_HEADER, TIMESTAMP_
26
26
  from flwr.supercore.primitives.asymmetric import public_key_to_bytes, sign_message
27
27
 
28
28
 
29
- class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type: ignore
30
- """Client interceptor for client authentication."""
29
+ class NodeAuthClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type: ignore
30
+ """Client interceptor for node authentication."""
31
31
 
32
32
  def __init__(
33
33
  self,
flwr/client/rest_client/connection.py CHANGED
@@ -15,6 +15,7 @@
15
15
  """Contextmanager for a REST request-response channel to the Flower server."""
16
16
 
17
17
 
18
+ import secrets
18
19
  from collections.abc import Iterator
19
20
  from contextlib import contextmanager
20
21
  from logging import ERROR, WARN
@@ -292,7 +293,12 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
292
293
 
293
294
  def create_node() -> Optional[int]:
294
295
  """Set create_node."""
295
- req = CreateNodeRequest(heartbeat_interval=HEARTBEAT_DEFAULT_INTERVAL)
296
+ req = CreateNodeRequest(
297
+ # REST does not support node authentication;
298
+ # random bytes are used instead
299
+ public_key=secrets.token_bytes(32),
300
+ heartbeat_interval=HEARTBEAT_DEFAULT_INTERVAL,
301
+ )
296
302
 
297
303
  # Send the request
298
304
  res = _request(req, CreateNodeResponse, PATH_CREATE_NODE)
flwr/common/constant.py CHANGED
@@ -256,6 +256,16 @@ class AuthnType:
256
256
  raise TypeError(f"{cls.__name__} cannot be instantiated.")
257
257
 
258
258
 
259
+ class AuthzType:
260
+ """Account authorization types."""
261
+
262
+ NOOP = "noop"
263
+
264
+ def __new__(cls) -> AuthzType:
265
+ """Prevent instantiation."""
266
+ raise TypeError(f"{cls.__name__} cannot be instantiated.")
267
+
268
+
259
269
  class EventLogWriterType:
260
270
  """Event log writer types."""
261
271
 
flwr/proto/fleet_pb2.py CHANGED
@@ -19,7 +19,7 @@ from flwr.proto import fab_pb2 as flwr_dot_proto_dot_fab__pb2
19
19
  from flwr.proto import message_pb2 as flwr_dot_proto_dot_message__pb2
20
20
 
21
21
 
22
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/fleet.proto\x12\nflwr.proto\x1a\x1a\x66lwr/proto/heartbeat.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x18\x66lwr/proto/message.proto\"/\n\x11\x43reateNodeRequest\x12\x1a\n\x12heartbeat_interval\x18\x01 \x01(\x01\"4\n\x12\x43reateNodeResponse\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"3\n\x11\x44\x65leteNodeRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"\x14\n\x12\x44\x65leteNodeResponse\"J\n\x13PullMessagesRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x13\n\x0bmessage_ids\x18\x02 \x03(\t\"\xa2\x01\n\x14PullMessagesResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12*\n\rmessages_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.Message\x12\x34\n\x14message_object_trees\x18\x03 \x03(\x0b\x32\x16.flwr.proto.ObjectTree\"\x97\x01\n\x13PushMessagesRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12*\n\rmessages_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.Message\x12\x34\n\x14message_object_trees\x18\x03 \x03(\x0b\x32\x16.flwr.proto.ObjectTree\"\xc9\x01\n\x14PushMessagesResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12>\n\x07results\x18\x02 \x03(\x0b\x32-.flwr.proto.PushMessagesResponse.ResultsEntry\x12\x17\n\x0fobjects_to_push\x18\x03 \x03(\t\x1a.\n\x0cResultsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\r:\x02\x38\x01\"\x1e\n\tReconnect\x12\x11\n\treconnect\x18\x01 \x01(\x04\x32\xca\x06\n\x05\x46leet\x12M\n\nCreateNode\x12\x1d.flwr.proto.CreateNodeRequest\x1a\x1e.flwr.proto.CreateNodeResponse\"\x00\x12M\n\nDeleteNode\x12\x1d.flwr.proto.DeleteNodeRequest\x1a\x1e.flwr.proto.DeleteNodeResponse\"\x00\x12\x62\n\x11SendNodeHeartbeat\x12$.flwr.proto.SendNodeHeartbeatRequest\x1a%.flwr.proto.SendNodeHeartbeatResponse\"\x00\x12S\n\x0cPullMessages\x12\x1f.flwr.proto.PullMessagesRequest\x1a 
.flwr.proto.PullMessagesResponse\"\x00\x12S\n\x0cPushMessages\x12\x1f.flwr.proto.PushMessagesRequest\x1a .flwr.proto.PushMessagesResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x12M\n\nPushObject\x12\x1d.flwr.proto.PushObjectRequest\x1a\x1e.flwr.proto.PushObjectResponse\"\x00\x12M\n\nPullObject\x12\x1d.flwr.proto.PullObjectRequest\x1a\x1e.flwr.proto.PullObjectResponse\"\x00\x12q\n\x16\x43onfirmMessageReceived\x12).flwr.proto.ConfirmMessageReceivedRequest\x1a*.flwr.proto.ConfirmMessageReceivedResponse\"\x00\x62\x06proto3')
22
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/fleet.proto\x12\nflwr.proto\x1a\x1a\x66lwr/proto/heartbeat.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x18\x66lwr/proto/message.proto\"C\n\x11\x43reateNodeRequest\x12\x12\n\npublic_key\x18\x01 \x01(\x0c\x12\x1a\n\x12heartbeat_interval\x18\x02 \x01(\x01\"4\n\x12\x43reateNodeResponse\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"3\n\x11\x44\x65leteNodeRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"\x14\n\x12\x44\x65leteNodeResponse\"J\n\x13PullMessagesRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x13\n\x0bmessage_ids\x18\x02 \x03(\t\"\xa2\x01\n\x14PullMessagesResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12*\n\rmessages_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.Message\x12\x34\n\x14message_object_trees\x18\x03 \x03(\x0b\x32\x16.flwr.proto.ObjectTree\"\x97\x01\n\x13PushMessagesRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12*\n\rmessages_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.Message\x12\x34\n\x14message_object_trees\x18\x03 \x03(\x0b\x32\x16.flwr.proto.ObjectTree\"\xc9\x01\n\x14PushMessagesResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12>\n\x07results\x18\x02 \x03(\x0b\x32-.flwr.proto.PushMessagesResponse.ResultsEntry\x12\x17\n\x0fobjects_to_push\x18\x03 \x03(\t\x1a.\n\x0cResultsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\r:\x02\x38\x01\"\x1e\n\tReconnect\x12\x11\n\treconnect\x18\x01 
\x01(\x04\x32\xca\x06\n\x05\x46leet\x12M\n\nCreateNode\x12\x1d.flwr.proto.CreateNodeRequest\x1a\x1e.flwr.proto.CreateNodeResponse\"\x00\x12M\n\nDeleteNode\x12\x1d.flwr.proto.DeleteNodeRequest\x1a\x1e.flwr.proto.DeleteNodeResponse\"\x00\x12\x62\n\x11SendNodeHeartbeat\x12$.flwr.proto.SendNodeHeartbeatRequest\x1a%.flwr.proto.SendNodeHeartbeatResponse\"\x00\x12S\n\x0cPullMessages\x12\x1f.flwr.proto.PullMessagesRequest\x1a .flwr.proto.PullMessagesResponse\"\x00\x12S\n\x0cPushMessages\x12\x1f.flwr.proto.PushMessagesRequest\x1a .flwr.proto.PushMessagesResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x12M\n\nPushObject\x12\x1d.flwr.proto.PushObjectRequest\x1a\x1e.flwr.proto.PushObjectResponse\"\x00\x12M\n\nPullObject\x12\x1d.flwr.proto.PullObjectRequest\x1a\x1e.flwr.proto.PullObjectResponse\"\x00\x12q\n\x16\x43onfirmMessageReceived\x12).flwr.proto.ConfirmMessageReceivedRequest\x1a*.flwr.proto.ConfirmMessageReceivedResponse\"\x00\x62\x06proto3')
23
23
 
24
24
  _globals = globals()
25
25
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -29,25 +29,25 @@ if _descriptor._USE_C_DESCRIPTORS == False:
29
29
  _globals['_PUSHMESSAGESRESPONSE_RESULTSENTRY']._options = None
30
30
  _globals['_PUSHMESSAGESRESPONSE_RESULTSENTRY']._serialized_options = b'8\001'
31
31
  _globals['_CREATENODEREQUEST']._serialized_start=159
32
- _globals['_CREATENODEREQUEST']._serialized_end=206
33
- _globals['_CREATENODERESPONSE']._serialized_start=208
34
- _globals['_CREATENODERESPONSE']._serialized_end=260
35
- _globals['_DELETENODEREQUEST']._serialized_start=262
36
- _globals['_DELETENODEREQUEST']._serialized_end=313
37
- _globals['_DELETENODERESPONSE']._serialized_start=315
38
- _globals['_DELETENODERESPONSE']._serialized_end=335
39
- _globals['_PULLMESSAGESREQUEST']._serialized_start=337
40
- _globals['_PULLMESSAGESREQUEST']._serialized_end=411
41
- _globals['_PULLMESSAGESRESPONSE']._serialized_start=414
42
- _globals['_PULLMESSAGESRESPONSE']._serialized_end=576
43
- _globals['_PUSHMESSAGESREQUEST']._serialized_start=579
44
- _globals['_PUSHMESSAGESREQUEST']._serialized_end=730
45
- _globals['_PUSHMESSAGESRESPONSE']._serialized_start=733
46
- _globals['_PUSHMESSAGESRESPONSE']._serialized_end=934
47
- _globals['_PUSHMESSAGESRESPONSE_RESULTSENTRY']._serialized_start=888
48
- _globals['_PUSHMESSAGESRESPONSE_RESULTSENTRY']._serialized_end=934
49
- _globals['_RECONNECT']._serialized_start=936
50
- _globals['_RECONNECT']._serialized_end=966
51
- _globals['_FLEET']._serialized_start=969
52
- _globals['_FLEET']._serialized_end=1811
32
+ _globals['_CREATENODEREQUEST']._serialized_end=226
33
+ _globals['_CREATENODERESPONSE']._serialized_start=228
34
+ _globals['_CREATENODERESPONSE']._serialized_end=280
35
+ _globals['_DELETENODEREQUEST']._serialized_start=282
36
+ _globals['_DELETENODEREQUEST']._serialized_end=333
37
+ _globals['_DELETENODERESPONSE']._serialized_start=335
38
+ _globals['_DELETENODERESPONSE']._serialized_end=355
39
+ _globals['_PULLMESSAGESREQUEST']._serialized_start=357
40
+ _globals['_PULLMESSAGESREQUEST']._serialized_end=431
41
+ _globals['_PULLMESSAGESRESPONSE']._serialized_start=434
42
+ _globals['_PULLMESSAGESRESPONSE']._serialized_end=596
43
+ _globals['_PUSHMESSAGESREQUEST']._serialized_start=599
44
+ _globals['_PUSHMESSAGESREQUEST']._serialized_end=750
45
+ _globals['_PUSHMESSAGESRESPONSE']._serialized_start=753
46
+ _globals['_PUSHMESSAGESRESPONSE']._serialized_end=954
47
+ _globals['_PUSHMESSAGESRESPONSE_RESULTSENTRY']._serialized_start=908
48
+ _globals['_PUSHMESSAGESRESPONSE_RESULTSENTRY']._serialized_end=954
49
+ _globals['_RECONNECT']._serialized_start=956
50
+ _globals['_RECONNECT']._serialized_end=986
51
+ _globals['_FLEET']._serialized_start=989
52
+ _globals['_FLEET']._serialized_end=1831
53
53
  # @@protoc_insertion_point(module_scope)
flwr/proto/fleet_pb2.pyi CHANGED
@@ -16,13 +16,16 @@ DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
16
16
  class CreateNodeRequest(google.protobuf.message.Message):
17
17
  """CreateNode messages"""
18
18
  DESCRIPTOR: google.protobuf.descriptor.Descriptor
19
+ PUBLIC_KEY_FIELD_NUMBER: builtins.int
19
20
  HEARTBEAT_INTERVAL_FIELD_NUMBER: builtins.int
21
+ public_key: builtins.bytes
20
22
  heartbeat_interval: builtins.float
21
23
  def __init__(self,
22
24
  *,
25
+ public_key: builtins.bytes = ...,
23
26
  heartbeat_interval: builtins.float = ...,
24
27
  ) -> None: ...
25
- def ClearField(self, field_name: typing_extensions.Literal["heartbeat_interval",b"heartbeat_interval"]) -> None: ...
28
+ def ClearField(self, field_name: typing_extensions.Literal["heartbeat_interval",b"heartbeat_interval","public_key",b"public_key"]) -> None: ...
26
29
  global___CreateNodeRequest = CreateNodeRequest
27
30
 
28
31
  class CreateNodeResponse(google.protobuf.message.Message):
flwr/proto/node_pb2.py CHANGED
@@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default()
14
14
 
15
15
 
16
16
 
17
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/node.proto\x12\nflwr.proto\"\x17\n\x04Node\x12\x0f\n\x07node_id\x18\x01 \x01(\x04\"\xd0\x01\n\x08NodeInfo\x12\x0f\n\x07node_id\x18\x01 \x01(\x04\x12\x11\n\towner_aid\x18\x02 \x01(\t\x12\x0e\n\x06status\x18\x03 \x01(\t\x12\x12\n\ncreated_at\x18\x04 \x01(\t\x12\x19\n\x11last_activated_at\x18\x05 \x01(\t\x12\x1b\n\x13last_deactivated_at\x18\x06 \x01(\t\x12\x12\n\ndeleted_at\x18\x07 \x01(\t\x12\x14\n\x0conline_until\x18\x08 \x01(\x02\x12\x1a\n\x12heartbeat_interval\x18\t \x01(\x02\x62\x06proto3')
17
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/node.proto\x12\nflwr.proto\"\x17\n\x04Node\x12\x0f\n\x07node_id\x18\x01 \x01(\x04\"\xd0\x01\n\x08NodeInfo\x12\x0f\n\x07node_id\x18\x01 \x01(\x04\x12\x11\n\towner_aid\x18\x02 \x01(\t\x12\x0e\n\x06status\x18\x03 \x01(\t\x12\x12\n\ncreated_at\x18\x04 \x01(\t\x12\x19\n\x11last_activated_at\x18\x05 \x01(\t\x12\x1b\n\x13last_deactivated_at\x18\x06 \x01(\t\x12\x12\n\ndeleted_at\x18\x07 \x01(\t\x12\x14\n\x0conline_until\x18\x08 \x01(\x01\x12\x1a\n\x12heartbeat_interval\x18\t \x01(\x01\x62\x06proto3')
18
18
 
19
19
  _globals = globals()
20
20
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
flwr/server/app.py CHANGED
@@ -26,7 +26,7 @@ from collections.abc import Sequence
26
26
  from logging import DEBUG, INFO, WARN
27
27
  from pathlib import Path
28
28
  from time import sleep
29
- from typing import Any, Callable, Optional, TypeVar
29
+ from typing import Callable, Optional, TypeVar, cast
30
30
 
31
31
  import grpc
32
32
  import yaml
@@ -52,6 +52,8 @@ from flwr.common.constant import (
52
52
  TRANSPORT_TYPE_GRPC_ADAPTER,
53
53
  TRANSPORT_TYPE_GRPC_RERE,
54
54
  TRANSPORT_TYPE_REST,
55
+ AuthnType,
56
+ AuthzType,
55
57
  EventLogWriterType,
56
58
  ExecPluginType,
57
59
  )
@@ -69,12 +71,19 @@ from flwr.supercore.grpc_health import add_args_health, run_health_server_grpc_n
69
71
  from flwr.supercore.object_store import ObjectStoreFactory
70
72
  from flwr.supercore.primitives.asymmetric import public_key_to_bytes
71
73
  from flwr.superlink.artifact_provider import ArtifactProvider
72
- from flwr.superlink.auth_plugin import ControlAuthnPlugin, ControlAuthzPlugin
74
+ from flwr.superlink.auth_plugin import (
75
+ ControlAuthnPlugin,
76
+ ControlAuthzPlugin,
77
+ get_control_authn_plugins,
78
+ get_control_authz_plugins,
79
+ )
73
80
  from flwr.superlink.servicer.control import run_control_api_grpc
74
81
 
75
82
  from .superlink.fleet.grpc_adapter.grpc_adapter_servicer import GrpcAdapterServicer
76
83
  from .superlink.fleet.grpc_rere.fleet_servicer import FleetServicer
77
- from .superlink.fleet.grpc_rere.server_interceptor import AuthenticateServerInterceptor
84
+ from .superlink.fleet.grpc_rere.node_auth_server_interceptor import (
85
+ NodeAuthServerInterceptor,
86
+ )
78
87
  from .superlink.linkstate import LinkStateFactory
79
88
  from .superlink.serverappio.serverappio_grpc import run_serverappio_api_grpc
80
89
  from .superlink.simulation.simulationio_grpc import run_simulationio_api_grpc
@@ -87,8 +96,6 @@ P = TypeVar("P", ControlAuthnPlugin, ControlAuthzPlugin)
87
96
  try:
88
97
  from flwr.ee import (
89
98
  add_ee_args_superlink,
90
- get_control_authn_plugins,
91
- get_control_authz_plugins,
92
99
  get_control_event_log_writer_plugins,
93
100
  get_ee_artifact_provider,
94
101
  get_fleet_event_log_writer_plugins,
@@ -99,14 +106,6 @@ except ImportError:
99
106
  def add_ee_args_superlink(parser: argparse.ArgumentParser) -> None:
100
107
  """Add EE-specific arguments to the parser."""
101
108
 
102
- def get_control_authn_plugins() -> dict[str, type[ControlAuthnPlugin]]:
103
- """Return all Control API authentication plugins."""
104
- raise NotImplementedError("No authentication plugins are currently supported.")
105
-
106
- def get_control_authz_plugins() -> dict[str, type[ControlAuthzPlugin]]:
107
- """Return all Control API authorization plugins."""
108
- raise NotImplementedError("No authorization plugins are currently supported.")
109
-
110
109
  def get_control_event_log_writer_plugins() -> dict[str, type[EventLogWriterPlugin]]:
111
110
  """Return all Control API event log writer plugins."""
112
111
  raise NotImplementedError(
@@ -202,10 +201,9 @@ def run_superlink() -> None:
202
201
  "future release. Please use `--account-auth-config` instead.",
203
202
  )
204
203
  args.account_auth_config = cfg_path
205
- if cfg_path := getattr(args, "account_auth_config", None):
206
- authn_plugin, authz_plugin = _try_obtain_control_auth_plugins(
207
- Path(cfg_path), verify_tls_cert
208
- )
204
+ cfg_path = getattr(args, "account_auth_config", None)
205
+ authn_plugin, authz_plugin = _load_control_auth_plugins(cfg_path, verify_tls_cert)
206
+ if cfg_path is not None:
209
207
  # Enable event logging if the args.enable_event_log is True
210
208
  if args.enable_event_log:
211
209
  event_log_plugin = _try_obtain_control_event_log_writer_plugin()
@@ -326,7 +324,7 @@ def run_superlink() -> None:
326
324
  else:
327
325
  log(DEBUG, "Automatic node authentication enabled")
328
326
 
329
- interceptors = [AuthenticateServerInterceptor(state_factory, auto_auth)]
327
+ interceptors = [NodeAuthServerInterceptor(state_factory, auto_auth)]
330
328
  if getattr(args, "enable_event_log", None):
331
329
  fleet_log_plugin = _try_obtain_fleet_event_log_writer_plugin()
332
330
  if fleet_log_plugin is not None:
@@ -447,13 +445,21 @@ def _try_load_public_keys_node_authentication(
447
445
  return node_public_keys
448
446
 
449
447
 
450
- def _try_obtain_control_auth_plugins(
451
- config_path: Path, verify_tls_cert: bool
448
+ def _load_control_auth_plugins(
449
+ config_path: Optional[str], verify_tls_cert: bool
452
450
  ) -> tuple[ControlAuthnPlugin, ControlAuthzPlugin]:
453
451
  """Obtain Control API authentication and authorization plugins."""
452
+ # Load NoOp plugins if no config path is provided
453
+ if config_path is None:
454
+ config_path = ""
455
+ config = {
456
+ "authentication": {AUTHN_TYPE_YAML_KEY: AuthnType.NOOP},
457
+ "authorization": {AUTHZ_TYPE_YAML_KEY: AuthzType.NOOP},
458
+ }
454
459
  # Load YAML file
455
- with config_path.open("r", encoding="utf-8") as file:
456
- config: dict[str, Any] = yaml.safe_load(file)
460
+ else:
461
+ with Path(config_path).open("r", encoding="utf-8") as file:
462
+ config = yaml.safe_load(file)
457
463
 
458
464
  def _load_plugin(
459
465
  section: str, yaml_key: str, loader: Callable[[], dict[str, type[P]]]
@@ -463,9 +469,7 @@ def _try_obtain_control_auth_plugins(
463
469
  try:
464
470
  plugins: dict[str, type[P]] = loader()
465
471
  plugin_cls: type[P] = plugins[auth_plugin_name]
466
- return plugin_cls(
467
- account_auth_config_path=config_path, verify_tls_cert=verify_tls_cert
468
- )
472
+ return plugin_cls(Path(cast(str, config_path)), verify_tls_cert)
469
473
  except KeyError:
470
474
  if auth_plugin_name:
471
475
  sys.exit(
@@ -473,18 +477,15 @@ def _try_obtain_control_auth_plugins(
473
477
  f"Please provide a valid {section} type in the configuration."
474
478
  )
475
479
  sys.exit(f"No {section} type is provided in the configuration.")
476
- except NotImplementedError:
477
- sys.exit(f"No {section} plugins are currently supported.")
478
480
 
479
- # Warn deprecated authn_type key
480
- if "authn_type" in config["authentication"]:
481
+ # Warn deprecated auth_type key
482
+ if authn_type := config["authentication"].pop("auth_type", None):
481
483
  log(
482
484
  WARN,
483
- "The `authn_type` key in the authentication configuration is deprecated. "
485
+ "The `auth_type` key in the authentication configuration is deprecated. "
484
486
  "Use `%s` instead.",
485
487
  AUTHN_TYPE_YAML_KEY,
486
488
  )
487
- authn_type = config["authentication"].pop("authn_type")
488
489
  config["authentication"][AUTHN_TYPE_YAML_KEY] = authn_type
489
490
 
490
491
  # Load authentication plugin
@@ -78,10 +78,14 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
78
78
  request.heartbeat_interval,
79
79
  )
80
80
  log(DEBUG, "[Fleet.CreateNode] Request: %s", MessageToDict(request))
81
- response = message_handler.create_node(
82
- request=request,
83
- state=self.state_factory.state(),
84
- )
81
+ try:
82
+ response = message_handler.create_node(
83
+ request=request,
84
+ state=self.state_factory.state(),
85
+ )
86
+ except ValueError as e:
87
+ # Public key already in use
88
+ context.abort(grpc.StatusCode.FAILED_PRECONDITION, str(e))
85
89
  log(INFO, "[Fleet.CreateNode] Created node_id=%s", response.node.node_id)
86
90
  log(DEBUG, "[Fleet.CreateNode] Response: %s", MessageToDict(response))
87
91
  return response