flwr-nightly 1.23.0.dev20251007__py3-none-any.whl → 1.23.0.dev20251009__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38) hide show
  1. flwr/cli/auth_plugin/__init__.py +7 -3
  2. flwr/cli/log.py +2 -2
  3. flwr/cli/login/login.py +4 -13
  4. flwr/cli/ls.py +2 -2
  5. flwr/cli/pull.py +2 -2
  6. flwr/cli/run/run.py +2 -2
  7. flwr/cli/stop.py +2 -2
  8. flwr/cli/supernode/create.py +137 -11
  9. flwr/cli/supernode/delete.py +88 -10
  10. flwr/cli/supernode/ls.py +2 -2
  11. flwr/cli/utils.py +65 -55
  12. flwr/client/grpc_rere_client/connection.py +6 -4
  13. flwr/client/grpc_rere_client/{client_interceptor.py → node_auth_client_interceptor.py} +2 -2
  14. flwr/client/rest_client/connection.py +7 -1
  15. flwr/common/constant.py +13 -0
  16. flwr/proto/control_pb2.py +1 -1
  17. flwr/proto/control_pb2.pyi +2 -2
  18. flwr/proto/fleet_pb2.py +22 -22
  19. flwr/proto/fleet_pb2.pyi +4 -1
  20. flwr/proto/node_pb2.py +2 -2
  21. flwr/proto/node_pb2.pyi +4 -1
  22. flwr/server/app.py +32 -31
  23. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +8 -4
  24. flwr/server/superlink/fleet/grpc_rere/{server_interceptor.py → node_auth_server_interceptor.py} +18 -37
  25. flwr/server/superlink/fleet/message_handler/message_handler.py +5 -3
  26. flwr/server/superlink/fleet/vce/vce_api.py +10 -1
  27. flwr/server/superlink/linkstate/in_memory_linkstate.py +52 -54
  28. flwr/server/superlink/linkstate/linkstate.py +20 -10
  29. flwr/server/superlink/linkstate/sqlite_linkstate.py +54 -61
  30. flwr/server/utils/validator.py +2 -3
  31. flwr/supercore/primitives/asymmetric.py +8 -0
  32. flwr/superlink/auth_plugin/__init__.py +29 -0
  33. flwr/superlink/servicer/control/control_grpc.py +9 -7
  34. flwr/superlink/servicer/control/control_servicer.py +89 -48
  35. {flwr_nightly-1.23.0.dev20251007.dist-info → flwr_nightly-1.23.0.dev20251009.dist-info}/METADATA +1 -1
  36. {flwr_nightly-1.23.0.dev20251007.dist-info → flwr_nightly-1.23.0.dev20251009.dist-info}/RECORD +38 -38
  37. {flwr_nightly-1.23.0.dev20251007.dist-info → flwr_nightly-1.23.0.dev20251009.dist-info}/WHEEL +0 -0
  38. {flwr_nightly-1.23.0.dev20251007.dist-info → flwr_nightly-1.23.0.dev20251009.dist-info}/entry_points.txt +0 -0
@@ -55,10 +55,10 @@ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
55
55
  from flwr.proto.message_pb2 import ObjectTree # pylint: disable=E0611
56
56
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
57
57
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
58
- from flwr.supercore.primitives.asymmetric import generate_key_pairs
58
+ from flwr.supercore.primitives.asymmetric import generate_key_pairs, public_key_to_bytes
59
59
 
60
- from .client_interceptor import AuthenticateClientInterceptor
61
60
  from .grpc_adapter import GrpcAdapter
61
+ from .node_auth_client_interceptor import NodeAuthClientInterceptor
62
62
 
63
63
 
64
64
  @contextmanager
@@ -138,8 +138,9 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
138
138
 
139
139
  # Always configure auth interceptor, with either user-provided or generated keys
140
140
  interceptors: Sequence[grpc.UnaryUnaryClientInterceptor] = [
141
- AuthenticateClientInterceptor(*authentication_keys),
141
+ NodeAuthClientInterceptor(*authentication_keys),
142
142
  ]
143
+ node_pk = public_key_to_bytes(authentication_keys[1])
143
144
  channel = create_channel(
144
145
  server_address=server_address,
145
146
  insecure=insecure,
@@ -199,7 +200,8 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
199
200
  """Set create_node."""
200
201
  # Call FleetAPI
201
202
  create_node_request = CreateNodeRequest(
202
- heartbeat_interval=HEARTBEAT_DEFAULT_INTERVAL
203
+ public_key=node_pk,
204
+ heartbeat_interval=HEARTBEAT_DEFAULT_INTERVAL,
203
205
  )
204
206
  create_node_response = stub.CreateNode(request=create_node_request)
205
207
 
@@ -26,8 +26,8 @@ from flwr.common.constant import PUBLIC_KEY_HEADER, SIGNATURE_HEADER, TIMESTAMP_
26
26
  from flwr.supercore.primitives.asymmetric import public_key_to_bytes, sign_message
27
27
 
28
28
 
29
- class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type: ignore
30
- """Client interceptor for client authentication."""
29
+ class NodeAuthClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type: ignore
30
+ """Client interceptor for node authentication."""
31
31
 
32
32
  def __init__(
33
33
  self,
@@ -15,6 +15,7 @@
15
15
  """Contextmanager for a REST request-response channel to the Flower server."""
16
16
 
17
17
 
18
+ import secrets
18
19
  from collections.abc import Iterator
19
20
  from contextlib import contextmanager
20
21
  from logging import ERROR, WARN
@@ -292,7 +293,12 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
292
293
 
293
294
  def create_node() -> Optional[int]:
294
295
  """Set create_node."""
295
- req = CreateNodeRequest(heartbeat_interval=HEARTBEAT_DEFAULT_INTERVAL)
296
+ req = CreateNodeRequest(
297
+ # REST does not support node authentication;
298
+ # random bytes are used instead
299
+ public_key=secrets.token_bytes(32),
300
+ heartbeat_interval=HEARTBEAT_DEFAULT_INTERVAL,
301
+ )
296
302
 
297
303
  # Send the request
298
304
  res = _request(req, CreateNodeResponse, PATH_CREATE_NODE)
flwr/common/constant.py CHANGED
@@ -157,6 +157,9 @@ RUN_ID_NOT_FOUND_MESSAGE = "Run ID not found"
157
157
  NO_ACCOUNT_AUTH_MESSAGE = "ControlServicer initialized without account authentication"
158
158
  NO_ARTIFACT_PROVIDER_MESSAGE = "ControlServicer initialized without artifact provider"
159
159
  PULL_UNFINISHED_RUN_MESSAGE = "Cannot pull artifacts for an unfinished run"
160
+ PUBLIC_KEY_ALREADY_IN_USE_MESSAGE = "Public key already in use"
161
+ PUBLIC_KEY_NOT_VALID = "The provided public key is not valid"
162
+ NODE_NOT_FOUND_MESSAGE = "Node ID not found for account"
160
163
 
161
164
 
162
165
  class MessageType:
@@ -256,6 +259,16 @@ class AuthnType:
256
259
  raise TypeError(f"{cls.__name__} cannot be instantiated.")
257
260
 
258
261
 
262
+ class AuthzType:
263
+ """Account authorization types."""
264
+
265
+ NOOP = "noop"
266
+
267
+ def __new__(cls) -> AuthzType:
268
+ """Prevent instantiation."""
269
+ raise TypeError(f"{cls.__name__} cannot be instantiated.")
270
+
271
+
259
272
  class EventLogWriterType:
260
273
  """Event log writer types."""
261
274
 
flwr/proto/control_pb2.py CHANGED
@@ -19,7 +19,7 @@ from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2
19
19
  from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2
20
20
 
21
21
 
22
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x66lwr/proto/control.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x1b\x66lwr/proto/recorddict.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x15\x66lwr/proto/node.proto\"\xfa\x01\n\x0fStartRunRequest\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fab\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x12\x34\n\x12\x66\x65\x64\x65ration_options\x18\x03 \x01(\x0b\x32\x18.flwr.proto.ConfigRecord\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"2\n\x10StartRunResponse\x12\x13\n\x06run_id\x18\x01 \x01(\x04H\x00\x88\x01\x01\x42\t\n\x07_run_id\"<\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x17\n\x0f\x61\x66ter_timestamp\x18\x02 \x01(\x01\"B\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t\x12\x18\n\x10latest_timestamp\x18\x02 \x01(\x01\"1\n\x0fListRunsRequest\x12\x13\n\x06run_id\x18\x01 \x01(\x04H\x00\x88\x01\x01\x42\t\n\x07_run_id\"\x9d\x01\n\x10ListRunsResponse\x12;\n\x08run_dict\x18\x01 \x03(\x0b\x32).flwr.proto.ListRunsResponse.RunDictEntry\x12\x0b\n\x03now\x18\x02 \x01(\t\x1a?\n\x0cRunDictEntry\x12\x0b\n\x03key\x18\x01 \x01(\x04\x12\x1e\n\x05value\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run:\x02\x38\x01\"\x18\n\x16GetLoginDetailsRequest\"\x8b\x01\n\x17GetLoginDetailsResponse\x12\x12\n\nauthn_type\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x65vice_code\x18\x02 \x01(\t\x12!\n\x19verification_uri_complete\x18\x03 \x01(\t\x12\x12\n\nexpires_in\x18\x04 \x01(\x03\x12\x10\n\x08interval\x18\x05 \x01(\x03\"+\n\x14GetAuthTokensRequest\x12\x13\n\x0b\x64\x65vice_code\x18\x01 \x01(\t\"D\n\x15GetAuthTokensResponse\x12\x14\n\x0c\x61\x63\x63\x65ss_token\x18\x01 \x01(\t\x12\x15\n\rrefresh_token\x18\x02 \x01(\t\" \n\x0eStopRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"\"\n\x0fStopRunResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"&\n\x14PullArtifactsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"1\n\x15PullArtifactsResponse\x12\x10\n\x03url\x18\x01 \x01(\tH\x00\x88\x01\x01\x42\x06\n\x04_url\"*\n\x14\x43reateNodeCliRequest\x12\x12\n\npublic_key\x18\x01 \x01(\t\"9\n\x15\x43reateNodeCliResponse\x12\x14\n\x07node_id\x18\x01 \x01(\x04H\x00\x88\x01\x01\x42\n\n\x08_node_id\"\'\n\x14\x44\x65leteNodeCliRequest\x12\x0f\n\x07node_id\x18\x01 \x01(\x04\"\x17\n\x15\x44\x65leteNodeCliResponse\"\x15\n\x13ListNodesCliRequest\"M\n\x14ListNodesCliResponse\x12(\n\nnodes_info\x18\x01 \x03(\x0b\x32\x14.flwr.proto.NodeInfo\x12\x0b\n\x03now\x18\x02 \x01(\t2\xc5\x06\n\x07\x43ontrol\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12\x44\n\x07StopRun\x12\x1a.flwr.proto.StopRunRequest\x1a\x1b.flwr.proto.StopRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x12G\n\x08ListRuns\x12\x1b.flwr.proto.ListRunsRequest\x1a\x1c.flwr.proto.ListRunsResponse\"\x00\x12\\\n\x0fGetLoginDetails\x12\".flwr.proto.GetLoginDetailsRequest\x1a#.flwr.proto.GetLoginDetailsResponse\"\x00\x12V\n\rGetAuthTokens\x12 .flwr.proto.GetAuthTokensRequest\x1a!.flwr.proto.GetAuthTokensResponse\"\x00\x12V\n\rPullArtifacts\x12 .flwr.proto.PullArtifactsRequest\x1a!.flwr.proto.PullArtifactsResponse\"\x00\x12V\n\rCreateNodeCli\x12 .flwr.proto.CreateNodeCliRequest\x1a!.flwr.proto.CreateNodeCliResponse\"\x00\x12V\n\rDeleteNodeCli\x12 .flwr.proto.DeleteNodeCliRequest\x1a!.flwr.proto.DeleteNodeCliResponse\"\x00\x12S\n\x0cListNodesCli\x12\x1f.flwr.proto.ListNodesCliRequest\x1a .flwr.proto.ListNodesCliResponse\"\x00\x62\x06proto3')
22
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x66lwr/proto/control.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x1b\x66lwr/proto/recorddict.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x15\x66lwr/proto/node.proto\"\xfa\x01\n\x0fStartRunRequest\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fab\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x12\x34\n\x12\x66\x65\x64\x65ration_options\x18\x03 \x01(\x0b\x32\x18.flwr.proto.ConfigRecord\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"2\n\x10StartRunResponse\x12\x13\n\x06run_id\x18\x01 \x01(\x04H\x00\x88\x01\x01\x42\t\n\x07_run_id\"<\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x17\n\x0f\x61\x66ter_timestamp\x18\x02 \x01(\x01\"B\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t\x12\x18\n\x10latest_timestamp\x18\x02 \x01(\x01\"1\n\x0fListRunsRequest\x12\x13\n\x06run_id\x18\x01 \x01(\x04H\x00\x88\x01\x01\x42\t\n\x07_run_id\"\x9d\x01\n\x10ListRunsResponse\x12;\n\x08run_dict\x18\x01 \x03(\x0b\x32).flwr.proto.ListRunsResponse.RunDictEntry\x12\x0b\n\x03now\x18\x02 \x01(\t\x1a?\n\x0cRunDictEntry\x12\x0b\n\x03key\x18\x01 \x01(\x04\x12\x1e\n\x05value\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run:\x02\x38\x01\"\x18\n\x16GetLoginDetailsRequest\"\x8b\x01\n\x17GetLoginDetailsResponse\x12\x12\n\nauthn_type\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x65vice_code\x18\x02 \x01(\t\x12!\n\x19verification_uri_complete\x18\x03 \x01(\t\x12\x12\n\nexpires_in\x18\x04 \x01(\x03\x12\x10\n\x08interval\x18\x05 \x01(\x03\"+\n\x14GetAuthTokensRequest\x12\x13\n\x0b\x64\x65vice_code\x18\x01 \x01(\t\"D\n\x15GetAuthTokensResponse\x12\x14\n\x0c\x61\x63\x63\x65ss_token\x18\x01 \x01(\t\x12\x15\n\rrefresh_token\x18\x02 \x01(\t\" \n\x0eStopRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"\"\n\x0fStopRunResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"&\n\x14PullArtifactsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"1\n\x15PullArtifactsResponse\x12\x10\n\x03url\x18\x01 \x01(\tH\x00\x88\x01\x01\x42\x06\n\x04_url\"*\n\x14\x43reateNodeCliRequest\x12\x12\n\npublic_key\x18\x01 \x01(\x0c\"9\n\x15\x43reateNodeCliResponse\x12\x14\n\x07node_id\x18\x01 \x01(\x04H\x00\x88\x01\x01\x42\n\n\x08_node_id\"\'\n\x14\x44\x65leteNodeCliRequest\x12\x0f\n\x07node_id\x18\x01 \x01(\x04\"\x17\n\x15\x44\x65leteNodeCliResponse\"\x15\n\x13ListNodesCliRequest\"M\n\x14ListNodesCliResponse\x12(\n\nnodes_info\x18\x01 \x03(\x0b\x32\x14.flwr.proto.NodeInfo\x12\x0b\n\x03now\x18\x02 \x01(\t2\xc5\x06\n\x07\x43ontrol\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12\x44\n\x07StopRun\x12\x1a.flwr.proto.StopRunRequest\x1a\x1b.flwr.proto.StopRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x12G\n\x08ListRuns\x12\x1b.flwr.proto.ListRunsRequest\x1a\x1c.flwr.proto.ListRunsResponse\"\x00\x12\\\n\x0fGetLoginDetails\x12\".flwr.proto.GetLoginDetailsRequest\x1a#.flwr.proto.GetLoginDetailsResponse\"\x00\x12V\n\rGetAuthTokens\x12 .flwr.proto.GetAuthTokensRequest\x1a!.flwr.proto.GetAuthTokensResponse\"\x00\x12V\n\rPullArtifacts\x12 .flwr.proto.PullArtifactsRequest\x1a!.flwr.proto.PullArtifactsResponse\"\x00\x12V\n\rCreateNodeCli\x12 .flwr.proto.CreateNodeCliRequest\x1a!.flwr.proto.CreateNodeCliResponse\"\x00\x12V\n\rDeleteNodeCli\x12 .flwr.proto.DeleteNodeCliRequest\x1a!.flwr.proto.DeleteNodeCliResponse\"\x00\x12S\n\x0cListNodesCli\x12\x1f.flwr.proto.ListNodesCliRequest\x1a .flwr.proto.ListNodesCliResponse\"\x00\x62\x06proto3')
23
23
 
24
24
  _globals = globals()
25
25
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -239,10 +239,10 @@ global___PullArtifactsResponse = PullArtifactsResponse
239
239
  class CreateNodeCliRequest(google.protobuf.message.Message):
240
240
  DESCRIPTOR: google.protobuf.descriptor.Descriptor
241
241
  PUBLIC_KEY_FIELD_NUMBER: builtins.int
242
- public_key: typing.Text
242
+ public_key: builtins.bytes
243
243
  def __init__(self,
244
244
  *,
245
- public_key: typing.Text = ...,
245
+ public_key: builtins.bytes = ...,
246
246
  ) -> None: ...
247
247
  def ClearField(self, field_name: typing_extensions.Literal["public_key",b"public_key"]) -> None: ...
248
248
  global___CreateNodeCliRequest = CreateNodeCliRequest
flwr/proto/fleet_pb2.py CHANGED
@@ -19,7 +19,7 @@ from flwr.proto import fab_pb2 as flwr_dot_proto_dot_fab__pb2
19
19
  from flwr.proto import message_pb2 as flwr_dot_proto_dot_message__pb2
20
20
 
21
21
 
22
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/fleet.proto\x12\nflwr.proto\x1a\x1a\x66lwr/proto/heartbeat.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x18\x66lwr/proto/message.proto\"/\n\x11\x43reateNodeRequest\x12\x1a\n\x12heartbeat_interval\x18\x01 \x01(\x01\"4\n\x12\x43reateNodeResponse\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"3\n\x11\x44\x65leteNodeRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"\x14\n\x12\x44\x65leteNodeResponse\"J\n\x13PullMessagesRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x13\n\x0bmessage_ids\x18\x02 \x03(\t\"\xa2\x01\n\x14PullMessagesResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12*\n\rmessages_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.Message\x12\x34\n\x14message_object_trees\x18\x03 \x03(\x0b\x32\x16.flwr.proto.ObjectTree\"\x97\x01\n\x13PushMessagesRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12*\n\rmessages_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.Message\x12\x34\n\x14message_object_trees\x18\x03 \x03(\x0b\x32\x16.flwr.proto.ObjectTree\"\xc9\x01\n\x14PushMessagesResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12>\n\x07results\x18\x02 \x03(\x0b\x32-.flwr.proto.PushMessagesResponse.ResultsEntry\x12\x17\n\x0fobjects_to_push\x18\x03 \x03(\t\x1a.\n\x0cResultsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\r:\x02\x38\x01\"\x1e\n\tReconnect\x12\x11\n\treconnect\x18\x01 \x01(\x04\x32\xca\x06\n\x05\x46leet\x12M\n\nCreateNode\x12\x1d.flwr.proto.CreateNodeRequest\x1a\x1e.flwr.proto.CreateNodeResponse\"\x00\x12M\n\nDeleteNode\x12\x1d.flwr.proto.DeleteNodeRequest\x1a\x1e.flwr.proto.DeleteNodeResponse\"\x00\x12\x62\n\x11SendNodeHeartbeat\x12$.flwr.proto.SendNodeHeartbeatRequest\x1a%.flwr.proto.SendNodeHeartbeatResponse\"\x00\x12S\n\x0cPullMessages\x12\x1f.flwr.proto.PullMessagesRequest\x1a .flwr.proto.PullMessagesResponse\"\x00\x12S\n\x0cPushMessages\x12\x1f.flwr.proto.PushMessagesRequest\x1a .flwr.proto.PushMessagesResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x12M\n\nPushObject\x12\x1d.flwr.proto.PushObjectRequest\x1a\x1e.flwr.proto.PushObjectResponse\"\x00\x12M\n\nPullObject\x12\x1d.flwr.proto.PullObjectRequest\x1a\x1e.flwr.proto.PullObjectResponse\"\x00\x12q\n\x16\x43onfirmMessageReceived\x12).flwr.proto.ConfirmMessageReceivedRequest\x1a*.flwr.proto.ConfirmMessageReceivedResponse\"\x00\x62\x06proto3')
22
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/fleet.proto\x12\nflwr.proto\x1a\x1a\x66lwr/proto/heartbeat.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x18\x66lwr/proto/message.proto\"C\n\x11\x43reateNodeRequest\x12\x12\n\npublic_key\x18\x01 \x01(\x0c\x12\x1a\n\x12heartbeat_interval\x18\x02 \x01(\x01\"4\n\x12\x43reateNodeResponse\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"3\n\x11\x44\x65leteNodeRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"\x14\n\x12\x44\x65leteNodeResponse\"J\n\x13PullMessagesRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x13\n\x0bmessage_ids\x18\x02 \x03(\t\"\xa2\x01\n\x14PullMessagesResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12*\n\rmessages_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.Message\x12\x34\n\x14message_object_trees\x18\x03 \x03(\x0b\x32\x16.flwr.proto.ObjectTree\"\x97\x01\n\x13PushMessagesRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12*\n\rmessages_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.Message\x12\x34\n\x14message_object_trees\x18\x03 \x03(\x0b\x32\x16.flwr.proto.ObjectTree\"\xc9\x01\n\x14PushMessagesResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12>\n\x07results\x18\x02 \x03(\x0b\x32-.flwr.proto.PushMessagesResponse.ResultsEntry\x12\x17\n\x0fobjects_to_push\x18\x03 \x03(\t\x1a.\n\x0cResultsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\r:\x02\x38\x01\"\x1e\n\tReconnect\x12\x11\n\treconnect\x18\x01 \x01(\x04\x32\xca\x06\n\x05\x46leet\x12M\n\nCreateNode\x12\x1d.flwr.proto.CreateNodeRequest\x1a\x1e.flwr.proto.CreateNodeResponse\"\x00\x12M\n\nDeleteNode\x12\x1d.flwr.proto.DeleteNodeRequest\x1a\x1e.flwr.proto.DeleteNodeResponse\"\x00\x12\x62\n\x11SendNodeHeartbeat\x12$.flwr.proto.SendNodeHeartbeatRequest\x1a%.flwr.proto.SendNodeHeartbeatResponse\"\x00\x12S\n\x0cPullMessages\x12\x1f.flwr.proto.PullMessagesRequest\x1a .flwr.proto.PullMessagesResponse\"\x00\x12S\n\x0cPushMessages\x12\x1f.flwr.proto.PushMessagesRequest\x1a .flwr.proto.PushMessagesResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x12M\n\nPushObject\x12\x1d.flwr.proto.PushObjectRequest\x1a\x1e.flwr.proto.PushObjectResponse\"\x00\x12M\n\nPullObject\x12\x1d.flwr.proto.PullObjectRequest\x1a\x1e.flwr.proto.PullObjectResponse\"\x00\x12q\n\x16\x43onfirmMessageReceived\x12).flwr.proto.ConfirmMessageReceivedRequest\x1a*.flwr.proto.ConfirmMessageReceivedResponse\"\x00\x62\x06proto3')
23
23
 
24
24
  _globals = globals()
25
25
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -29,25 +29,25 @@ if _descriptor._USE_C_DESCRIPTORS == False:
29
29
  _globals['_PUSHMESSAGESRESPONSE_RESULTSENTRY']._options = None
30
30
  _globals['_PUSHMESSAGESRESPONSE_RESULTSENTRY']._serialized_options = b'8\001'
31
31
  _globals['_CREATENODEREQUEST']._serialized_start=159
32
- _globals['_CREATENODEREQUEST']._serialized_end=206
33
- _globals['_CREATENODERESPONSE']._serialized_start=208
34
- _globals['_CREATENODERESPONSE']._serialized_end=260
35
- _globals['_DELETENODEREQUEST']._serialized_start=262
36
- _globals['_DELETENODEREQUEST']._serialized_end=313
37
- _globals['_DELETENODERESPONSE']._serialized_start=315
38
- _globals['_DELETENODERESPONSE']._serialized_end=335
39
- _globals['_PULLMESSAGESREQUEST']._serialized_start=337
40
- _globals['_PULLMESSAGESREQUEST']._serialized_end=411
41
- _globals['_PULLMESSAGESRESPONSE']._serialized_start=414
42
- _globals['_PULLMESSAGESRESPONSE']._serialized_end=576
43
- _globals['_PUSHMESSAGESREQUEST']._serialized_start=579
44
- _globals['_PUSHMESSAGESREQUEST']._serialized_end=730
45
- _globals['_PUSHMESSAGESRESPONSE']._serialized_start=733
46
- _globals['_PUSHMESSAGESRESPONSE']._serialized_end=934
47
- _globals['_PUSHMESSAGESRESPONSE_RESULTSENTRY']._serialized_start=888
48
- _globals['_PUSHMESSAGESRESPONSE_RESULTSENTRY']._serialized_end=934
49
- _globals['_RECONNECT']._serialized_start=936
50
- _globals['_RECONNECT']._serialized_end=966
51
- _globals['_FLEET']._serialized_start=969
52
- _globals['_FLEET']._serialized_end=1811
32
+ _globals['_CREATENODEREQUEST']._serialized_end=226
33
+ _globals['_CREATENODERESPONSE']._serialized_start=228
34
+ _globals['_CREATENODERESPONSE']._serialized_end=280
35
+ _globals['_DELETENODEREQUEST']._serialized_start=282
36
+ _globals['_DELETENODEREQUEST']._serialized_end=333
37
+ _globals['_DELETENODERESPONSE']._serialized_start=335
38
+ _globals['_DELETENODERESPONSE']._serialized_end=355
39
+ _globals['_PULLMESSAGESREQUEST']._serialized_start=357
40
+ _globals['_PULLMESSAGESREQUEST']._serialized_end=431
41
+ _globals['_PULLMESSAGESRESPONSE']._serialized_start=434
42
+ _globals['_PULLMESSAGESRESPONSE']._serialized_end=596
43
+ _globals['_PUSHMESSAGESREQUEST']._serialized_start=599
44
+ _globals['_PUSHMESSAGESREQUEST']._serialized_end=750
45
+ _globals['_PUSHMESSAGESRESPONSE']._serialized_start=753
46
+ _globals['_PUSHMESSAGESRESPONSE']._serialized_end=954
47
+ _globals['_PUSHMESSAGESRESPONSE_RESULTSENTRY']._serialized_start=908
48
+ _globals['_PUSHMESSAGESRESPONSE_RESULTSENTRY']._serialized_end=954
49
+ _globals['_RECONNECT']._serialized_start=956
50
+ _globals['_RECONNECT']._serialized_end=986
51
+ _globals['_FLEET']._serialized_start=989
52
+ _globals['_FLEET']._serialized_end=1831
53
53
  # @@protoc_insertion_point(module_scope)
flwr/proto/fleet_pb2.pyi CHANGED
@@ -16,13 +16,16 @@ DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
16
16
  class CreateNodeRequest(google.protobuf.message.Message):
17
17
  """CreateNode messages"""
18
18
  DESCRIPTOR: google.protobuf.descriptor.Descriptor
19
+ PUBLIC_KEY_FIELD_NUMBER: builtins.int
19
20
  HEARTBEAT_INTERVAL_FIELD_NUMBER: builtins.int
21
+ public_key: builtins.bytes
20
22
  heartbeat_interval: builtins.float
21
23
  def __init__(self,
22
24
  *,
25
+ public_key: builtins.bytes = ...,
23
26
  heartbeat_interval: builtins.float = ...,
24
27
  ) -> None: ...
25
- def ClearField(self, field_name: typing_extensions.Literal["heartbeat_interval",b"heartbeat_interval"]) -> None: ...
28
+ def ClearField(self, field_name: typing_extensions.Literal["heartbeat_interval",b"heartbeat_interval","public_key",b"public_key"]) -> None: ...
26
29
  global___CreateNodeRequest = CreateNodeRequest
27
30
 
28
31
  class CreateNodeResponse(google.protobuf.message.Message):
flwr/proto/node_pb2.py CHANGED
@@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default()
14
14
 
15
15
 
16
16
 
17
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/node.proto\x12\nflwr.proto\"\x17\n\x04Node\x12\x0f\n\x07node_id\x18\x01 \x01(\x04\"\xd0\x01\n\x08NodeInfo\x12\x0f\n\x07node_id\x18\x01 \x01(\x04\x12\x11\n\towner_aid\x18\x02 \x01(\t\x12\x0e\n\x06status\x18\x03 \x01(\t\x12\x12\n\ncreated_at\x18\x04 \x01(\t\x12\x19\n\x11last_activated_at\x18\x05 \x01(\t\x12\x1b\n\x13last_deactivated_at\x18\x06 \x01(\t\x12\x12\n\ndeleted_at\x18\x07 \x01(\t\x12\x14\n\x0conline_until\x18\x08 \x01(\x02\x12\x1a\n\x12heartbeat_interval\x18\t \x01(\x02\x62\x06proto3')
17
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/node.proto\x12\nflwr.proto\"\x17\n\x04Node\x12\x0f\n\x07node_id\x18\x01 \x01(\x04\"\xe4\x01\n\x08NodeInfo\x12\x0f\n\x07node_id\x18\x01 \x01(\x04\x12\x11\n\towner_aid\x18\x02 \x01(\t\x12\x0e\n\x06status\x18\x03 \x01(\t\x12\x12\n\ncreated_at\x18\x04 \x01(\t\x12\x19\n\x11last_activated_at\x18\x05 \x01(\t\x12\x1b\n\x13last_deactivated_at\x18\x06 \x01(\t\x12\x12\n\ndeleted_at\x18\x07 \x01(\t\x12\x14\n\x0conline_until\x18\x08 \x01(\x01\x12\x1a\n\x12heartbeat_interval\x18\t \x01(\x01\x12\x12\n\npublic_key\x18\n \x01(\x0c\x62\x06proto3')
18
18
 
19
19
  _globals = globals()
20
20
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -24,5 +24,5 @@ if _descriptor._USE_C_DESCRIPTORS == False:
24
24
  _globals['_NODE']._serialized_start=37
25
25
  _globals['_NODE']._serialized_end=60
26
26
  _globals['_NODEINFO']._serialized_start=63
27
- _globals['_NODEINFO']._serialized_end=271
27
+ _globals['_NODEINFO']._serialized_end=291
28
28
  # @@protoc_insertion_point(module_scope)
flwr/proto/node_pb2.pyi CHANGED
@@ -32,6 +32,7 @@ class NodeInfo(google.protobuf.message.Message):
32
32
  DELETED_AT_FIELD_NUMBER: builtins.int
33
33
  ONLINE_UNTIL_FIELD_NUMBER: builtins.int
34
34
  HEARTBEAT_INTERVAL_FIELD_NUMBER: builtins.int
35
+ PUBLIC_KEY_FIELD_NUMBER: builtins.int
35
36
  node_id: builtins.int
36
37
  owner_aid: typing.Text
37
38
  status: typing.Text
@@ -41,6 +42,7 @@ class NodeInfo(google.protobuf.message.Message):
41
42
  deleted_at: typing.Text
42
43
  online_until: builtins.float
43
44
  heartbeat_interval: builtins.float
45
+ public_key: builtins.bytes
44
46
  def __init__(self,
45
47
  *,
46
48
  node_id: builtins.int = ...,
@@ -52,6 +54,7 @@ class NodeInfo(google.protobuf.message.Message):
52
54
  deleted_at: typing.Text = ...,
53
55
  online_until: builtins.float = ...,
54
56
  heartbeat_interval: builtins.float = ...,
57
+ public_key: builtins.bytes = ...,
55
58
  ) -> None: ...
56
- def ClearField(self, field_name: typing_extensions.Literal["created_at",b"created_at","deleted_at",b"deleted_at","heartbeat_interval",b"heartbeat_interval","last_activated_at",b"last_activated_at","last_deactivated_at",b"last_deactivated_at","node_id",b"node_id","online_until",b"online_until","owner_aid",b"owner_aid","status",b"status"]) -> None: ...
59
+ def ClearField(self, field_name: typing_extensions.Literal["created_at",b"created_at","deleted_at",b"deleted_at","heartbeat_interval",b"heartbeat_interval","last_activated_at",b"last_activated_at","last_deactivated_at",b"last_deactivated_at","node_id",b"node_id","online_until",b"online_until","owner_aid",b"owner_aid","public_key",b"public_key","status",b"status"]) -> None: ...
57
60
  global___NodeInfo = NodeInfo
flwr/server/app.py CHANGED
@@ -26,7 +26,7 @@ from collections.abc import Sequence
26
26
  from logging import DEBUG, INFO, WARN
27
27
  from pathlib import Path
28
28
  from time import sleep
29
- from typing import Any, Callable, Optional, TypeVar
29
+ from typing import Callable, Optional, TypeVar, cast
30
30
 
31
31
  import grpc
32
32
  import yaml
@@ -52,6 +52,8 @@ from flwr.common.constant import (
52
52
  TRANSPORT_TYPE_GRPC_ADAPTER,
53
53
  TRANSPORT_TYPE_GRPC_RERE,
54
54
  TRANSPORT_TYPE_REST,
55
+ AuthnType,
56
+ AuthzType,
55
57
  EventLogWriterType,
56
58
  ExecPluginType,
57
59
  )
@@ -69,12 +71,19 @@ from flwr.supercore.grpc_health import add_args_health, run_health_server_grpc_n
69
71
  from flwr.supercore.object_store import ObjectStoreFactory
70
72
  from flwr.supercore.primitives.asymmetric import public_key_to_bytes
71
73
  from flwr.superlink.artifact_provider import ArtifactProvider
72
- from flwr.superlink.auth_plugin import ControlAuthnPlugin, ControlAuthzPlugin
74
+ from flwr.superlink.auth_plugin import (
75
+ ControlAuthnPlugin,
76
+ ControlAuthzPlugin,
77
+ get_control_authn_plugins,
78
+ get_control_authz_plugins,
79
+ )
73
80
  from flwr.superlink.servicer.control import run_control_api_grpc
74
81
 
75
82
  from .superlink.fleet.grpc_adapter.grpc_adapter_servicer import GrpcAdapterServicer
76
83
  from .superlink.fleet.grpc_rere.fleet_servicer import FleetServicer
77
- from .superlink.fleet.grpc_rere.server_interceptor import AuthenticateServerInterceptor
84
+ from .superlink.fleet.grpc_rere.node_auth_server_interceptor import (
85
+ NodeAuthServerInterceptor,
86
+ )
78
87
  from .superlink.linkstate import LinkStateFactory
79
88
  from .superlink.serverappio.serverappio_grpc import run_serverappio_api_grpc
80
89
  from .superlink.simulation.simulationio_grpc import run_simulationio_api_grpc
@@ -87,8 +96,6 @@ P = TypeVar("P", ControlAuthnPlugin, ControlAuthzPlugin)
87
96
  try:
88
97
  from flwr.ee import (
89
98
  add_ee_args_superlink,
90
- get_control_authn_plugins,
91
- get_control_authz_plugins,
92
99
  get_control_event_log_writer_plugins,
93
100
  get_ee_artifact_provider,
94
101
  get_fleet_event_log_writer_plugins,
@@ -99,14 +106,6 @@ except ImportError:
99
106
  def add_ee_args_superlink(parser: argparse.ArgumentParser) -> None:
100
107
  """Add EE-specific arguments to the parser."""
101
108
 
102
- def get_control_authn_plugins() -> dict[str, type[ControlAuthnPlugin]]:
103
- """Return all Control API authentication plugins."""
104
- raise NotImplementedError("No authentication plugins are currently supported.")
105
-
106
- def get_control_authz_plugins() -> dict[str, type[ControlAuthzPlugin]]:
107
- """Return all Control API authorization plugins."""
108
- raise NotImplementedError("No authorization plugins are currently supported.")
109
-
110
109
  def get_control_event_log_writer_plugins() -> dict[str, type[EventLogWriterPlugin]]:
111
110
  """Return all Control API event log writer plugins."""
112
111
  raise NotImplementedError(
@@ -202,10 +201,9 @@ def run_superlink() -> None:
202
201
  "future release. Please use `--account-auth-config` instead.",
203
202
  )
204
203
  args.account_auth_config = cfg_path
205
- if cfg_path := getattr(args, "account_auth_config", None):
206
- authn_plugin, authz_plugin = _try_obtain_control_auth_plugins(
207
- Path(cfg_path), verify_tls_cert
208
- )
204
+ cfg_path = getattr(args, "account_auth_config", None)
205
+ authn_plugin, authz_plugin = _load_control_auth_plugins(cfg_path, verify_tls_cert)
206
+ if cfg_path is not None:
209
207
  # Enable event logging if the args.enable_event_log is True
210
208
  if args.enable_event_log:
211
209
  event_log_plugin = _try_obtain_control_event_log_writer_plugin()
@@ -326,7 +324,7 @@ def run_superlink() -> None:
326
324
  else:
327
325
  log(DEBUG, "Automatic node authentication enabled")
328
326
 
329
- interceptors = [AuthenticateServerInterceptor(state_factory, auto_auth)]
327
+ interceptors = [NodeAuthServerInterceptor(state_factory, auto_auth)]
330
328
  if getattr(args, "enable_event_log", None):
331
329
  fleet_log_plugin = _try_obtain_fleet_event_log_writer_plugin()
332
330
  if fleet_log_plugin is not None:
@@ -447,13 +445,21 @@ def _try_load_public_keys_node_authentication(
447
445
  return node_public_keys
448
446
 
449
447
 
450
- def _try_obtain_control_auth_plugins(
451
- config_path: Path, verify_tls_cert: bool
448
+ def _load_control_auth_plugins(
449
+ config_path: Optional[str], verify_tls_cert: bool
452
450
  ) -> tuple[ControlAuthnPlugin, ControlAuthzPlugin]:
453
451
  """Obtain Control API authentication and authorization plugins."""
452
+ # Load NoOp plugins if no config path is provided
453
+ if config_path is None:
454
+ config_path = ""
455
+ config = {
456
+ "authentication": {AUTHN_TYPE_YAML_KEY: AuthnType.NOOP},
457
+ "authorization": {AUTHZ_TYPE_YAML_KEY: AuthzType.NOOP},
458
+ }
454
459
  # Load YAML file
455
- with config_path.open("r", encoding="utf-8") as file:
456
- config: dict[str, Any] = yaml.safe_load(file)
460
+ else:
461
+ with Path(config_path).open("r", encoding="utf-8") as file:
462
+ config = yaml.safe_load(file)
457
463
 
458
464
  def _load_plugin(
459
465
  section: str, yaml_key: str, loader: Callable[[], dict[str, type[P]]]
@@ -463,9 +469,7 @@ def _try_obtain_control_auth_plugins(
463
469
  try:
464
470
  plugins: dict[str, type[P]] = loader()
465
471
  plugin_cls: type[P] = plugins[auth_plugin_name]
466
- return plugin_cls(
467
- account_auth_config_path=config_path, verify_tls_cert=verify_tls_cert
468
- )
472
+ return plugin_cls(Path(cast(str, config_path)), verify_tls_cert)
469
473
  except KeyError:
470
474
  if auth_plugin_name:
471
475
  sys.exit(
@@ -473,18 +477,15 @@ def _try_obtain_control_auth_plugins(
473
477
  f"Please provide a valid {section} type in the configuration."
474
478
  )
475
479
  sys.exit(f"No {section} type is provided in the configuration.")
476
- except NotImplementedError:
477
- sys.exit(f"No {section} plugins are currently supported.")
478
480
 
479
- # Warn deprecated authn_type key
480
- if "authn_type" in config["authentication"]:
481
+ # Warn deprecated auth_type key
482
+ if authn_type := config["authentication"].pop("auth_type", None):
481
483
  log(
482
484
  WARN,
483
- "The `authn_type` key in the authentication configuration is deprecated. "
485
+ "The `auth_type` key in the authentication configuration is deprecated. "
484
486
  "Use `%s` instead.",
485
487
  AUTHN_TYPE_YAML_KEY,
486
488
  )
487
- authn_type = config["authentication"].pop("authn_type")
488
489
  config["authentication"][AUTHN_TYPE_YAML_KEY] = authn_type
489
490
 
490
491
  # Load authentication plugin
@@ -78,10 +78,14 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
78
78
  request.heartbeat_interval,
79
79
  )
80
80
  log(DEBUG, "[Fleet.CreateNode] Request: %s", MessageToDict(request))
81
- response = message_handler.create_node(
82
- request=request,
83
- state=self.state_factory.state(),
84
- )
81
+ try:
82
+ response = message_handler.create_node(
83
+ request=request,
84
+ state=self.state_factory.state(),
85
+ )
86
+ except ValueError as e:
87
+ # Public key already in use
88
+ context.abort(grpc.StatusCode.FAILED_PRECONDITION, str(e))
85
89
  log(INFO, "[Fleet.CreateNode] Created node_id=%s", response.node.node_id)
86
90
  log(DEBUG, "[Fleet.CreateNode] Response: %s", MessageToDict(response))
87
91
  return response
@@ -16,7 +16,7 @@
16
16
 
17
17
 
18
18
  import datetime
19
- from typing import Any, Callable, Optional, cast
19
+ from typing import Any, Callable, cast
20
20
 
21
21
  import grpc
22
22
  from google.protobuf.message import Message as GrpcMessage
@@ -29,10 +29,7 @@ from flwr.common.constant import (
29
29
  TIMESTAMP_HEADER,
30
30
  TIMESTAMP_TOLERANCE,
31
31
  )
32
- from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
33
- CreateNodeRequest,
34
- CreateNodeResponse,
35
- )
32
+ from flwr.proto.fleet_pb2 import CreateNodeRequest # pylint: disable=E0611
36
33
  from flwr.server.superlink.linkstate import LinkStateFactory
37
34
  from flwr.supercore.primitives.asymmetric import bytes_to_public_key, verify_signature
38
35
 
@@ -50,7 +47,7 @@ def _unary_unary_rpc_terminator(
50
47
  return grpc.unary_unary_rpc_method_handler(terminate)
51
48
 
52
49
 
53
- class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
50
+ class NodeAuthServerInterceptor(grpc.ServerInterceptor): # type: ignore
54
51
  """Server interceptor for node authentication.
55
52
 
56
53
  Parameters
@@ -110,50 +107,34 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
110
107
  if not MIN_TIMESTAMP_DIFF < time_diff.total_seconds() < MAX_TIMESTAMP_DIFF:
111
108
  return _unary_unary_rpc_terminator("Invalid timestamp")
112
109
 
113
- # Continue the RPC call
114
- expected_node_id = state.get_node_id(node_pk_bytes)
115
- if not handler_call_details.method.endswith("CreateNode"):
116
- # All calls, except for `CreateNode`, must provide a public key that is
117
- # already mapped to a `node_id` (in `LinkState`)
118
- if expected_node_id is None:
119
- return _unary_unary_rpc_terminator("Invalid node ID")
120
- # One of the method handlers in
110
+ # Continue the RPC call: One of the method handlers in
121
111
  # `flwr.server.superlink.fleet.grpc_rere.fleet_server.FleetServicer`
122
112
  method_handler: grpc.RpcMethodHandler = continuation(handler_call_details)
123
- return self._wrap_method_handler(
124
- method_handler, expected_node_id, node_pk_bytes
125
- )
113
+ return self._wrap_method_handler(method_handler, node_pk_bytes)
126
114
 
127
115
  def _wrap_method_handler(
128
116
  self,
129
117
  method_handler: grpc.RpcMethodHandler,
130
- expected_node_id: Optional[int],
131
- node_public_key: bytes,
118
+ expected_public_key: bytes,
132
119
  ) -> grpc.RpcMethodHandler:
133
120
  def _generic_method_handler(
134
121
  request: GrpcMessage,
135
122
  context: grpc.ServicerContext,
136
123
  ) -> GrpcMessage:
137
- # Verify the node ID
138
- if not isinstance(request, CreateNodeRequest):
139
- try:
140
- if request.node.node_id != expected_node_id: # type: ignore
141
- raise ValueError
142
- except (AttributeError, ValueError):
143
- context.abort(grpc.StatusCode.UNAUTHENTICATED, "Invalid node ID")
124
+ # Retrieve the public key
125
+ if isinstance(request, CreateNodeRequest):
126
+ actual_public_key = request.public_key
127
+ else:
128
+ # Note: This function runs in a different thread
129
+ # than the `intercept_service` function.
130
+ actual_public_key = self.state_factory.state().get_node_public_key(
131
+ request.node.node_id # type: ignore
132
+ )
133
+ # Verify the public key
134
+ if actual_public_key != expected_public_key:
135
+ context.abort(grpc.StatusCode.UNAUTHENTICATED, "Invalid node ID")
144
136
 
145
137
  response: GrpcMessage = method_handler.unary_unary(request, context)
146
-
147
- # Set the public key after a successful CreateNode request
148
- if isinstance(response, CreateNodeResponse):
149
- state = self.state_factory.state()
150
- try:
151
- state.set_node_public_key(response.node.node_id, node_public_key)
152
- except ValueError as e:
153
- # Remove newly created node if setting the public key fails
154
- state.delete_node(response.node.node_id)
155
- context.abort(grpc.StatusCode.UNAUTHENTICATED, str(e))
156
-
157
138
  return response
158
139
 
159
140
  return grpc.unary_unary_rpc_method_handler(
@@ -18,7 +18,7 @@ from logging import ERROR
18
18
  from typing import Optional
19
19
 
20
20
  from flwr.common import Message, log
21
- from flwr.common.constant import Status
21
+ from flwr.common.constant import NOOP_FLWR_AID, Status
22
22
  from flwr.common.inflatable import UnexpectedObjectContentError
23
23
  from flwr.common.serde import (
24
24
  fab_to_proto,
@@ -70,7 +70,9 @@ def create_node(
70
70
  ) -> CreateNodeResponse:
71
71
  """."""
72
72
  # Create node
73
- node_id = state.create_node(heartbeat_interval=request.heartbeat_interval)
73
+ node_id = state.create_node(
74
+ NOOP_FLWR_AID, request.public_key, request.heartbeat_interval
75
+ )
74
76
  return CreateNodeResponse(node=Node(node_id=node_id))
75
77
 
76
78
 
@@ -81,7 +83,7 @@ def delete_node(request: DeleteNodeRequest, state: LinkState) -> DeleteNodeRespo
81
83
  return DeleteNodeResponse()
82
84
 
83
85
  # Update state
84
- state.delete_node(node_id=request.node.node_id)
86
+ state.delete_node(NOOP_FLWR_AID, node_id=request.node.node_id)
85
87
  return DeleteNodeResponse()
86
88
 
87
89
 
@@ -16,6 +16,7 @@
16
16
 
17
17
 
18
18
  import json
19
+ import secrets
19
20
  import threading
20
21
  import time
21
22
  import traceback
@@ -33,6 +34,7 @@ from flwr.clientapp.utils import get_load_client_app_fn
33
34
  from flwr.common import Message
34
35
  from flwr.common.constant import (
35
36
  HEARTBEAT_MAX_INTERVAL,
37
+ NOOP_FLWR_AID,
36
38
  NUM_PARTITIONS_KEY,
37
39
  PARTITION_ID_KEY,
38
40
  ErrorCode,
@@ -53,7 +55,14 @@ def _register_nodes(
53
55
  nodes_mapping: NodeToPartitionMapping = {}
54
56
  state = state_factory.state()
55
57
  for i in range(num_nodes):
56
- node_id = state.create_node(heartbeat_interval=HEARTBEAT_MAX_INTERVAL)
58
+ node_id = state.create_node(
59
+ # No node authentication in simulation;
60
+ # use NOOP_FLWR_AID as owner_aid and
61
+ # use random bytes as public key
62
+ NOOP_FLWR_AID,
63
+ secrets.token_bytes(32),
64
+ heartbeat_interval=HEARTBEAT_MAX_INTERVAL,
65
+ )
57
66
  nodes_mapping[node_id] = i
58
67
  log(DEBUG, "Registered %i nodes", len(nodes_mapping))
59
68
  return nodes_mapping