flwr-nightly 1.23.0.dev20251007__py3-none-any.whl → 1.23.0.dev20251009__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/auth_plugin/__init__.py +7 -3
- flwr/cli/log.py +2 -2
- flwr/cli/login/login.py +4 -13
- flwr/cli/ls.py +2 -2
- flwr/cli/pull.py +2 -2
- flwr/cli/run/run.py +2 -2
- flwr/cli/stop.py +2 -2
- flwr/cli/supernode/create.py +137 -11
- flwr/cli/supernode/delete.py +88 -10
- flwr/cli/supernode/ls.py +2 -2
- flwr/cli/utils.py +65 -55
- flwr/client/grpc_rere_client/connection.py +6 -4
- flwr/client/grpc_rere_client/{client_interceptor.py → node_auth_client_interceptor.py} +2 -2
- flwr/client/rest_client/connection.py +7 -1
- flwr/common/constant.py +13 -0
- flwr/proto/control_pb2.py +1 -1
- flwr/proto/control_pb2.pyi +2 -2
- flwr/proto/fleet_pb2.py +22 -22
- flwr/proto/fleet_pb2.pyi +4 -1
- flwr/proto/node_pb2.py +2 -2
- flwr/proto/node_pb2.pyi +4 -1
- flwr/server/app.py +32 -31
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +8 -4
- flwr/server/superlink/fleet/grpc_rere/{server_interceptor.py → node_auth_server_interceptor.py} +18 -37
- flwr/server/superlink/fleet/message_handler/message_handler.py +5 -3
- flwr/server/superlink/fleet/vce/vce_api.py +10 -1
- flwr/server/superlink/linkstate/in_memory_linkstate.py +52 -54
- flwr/server/superlink/linkstate/linkstate.py +20 -10
- flwr/server/superlink/linkstate/sqlite_linkstate.py +54 -61
- flwr/server/utils/validator.py +2 -3
- flwr/supercore/primitives/asymmetric.py +8 -0
- flwr/superlink/auth_plugin/__init__.py +29 -0
- flwr/superlink/servicer/control/control_grpc.py +9 -7
- flwr/superlink/servicer/control/control_servicer.py +89 -48
- {flwr_nightly-1.23.0.dev20251007.dist-info → flwr_nightly-1.23.0.dev20251009.dist-info}/METADATA +1 -1
- {flwr_nightly-1.23.0.dev20251007.dist-info → flwr_nightly-1.23.0.dev20251009.dist-info}/RECORD +38 -38
- {flwr_nightly-1.23.0.dev20251007.dist-info → flwr_nightly-1.23.0.dev20251009.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.23.0.dev20251007.dist-info → flwr_nightly-1.23.0.dev20251009.dist-info}/entry_points.txt +0 -0
@@ -55,10 +55,10 @@ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
|
|
55
55
|
from flwr.proto.message_pb2 import ObjectTree # pylint: disable=E0611
|
56
56
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
57
57
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
58
|
-
from flwr.supercore.primitives.asymmetric import generate_key_pairs
|
58
|
+
from flwr.supercore.primitives.asymmetric import generate_key_pairs, public_key_to_bytes
|
59
59
|
|
60
|
-
from .client_interceptor import AuthenticateClientInterceptor
|
61
60
|
from .grpc_adapter import GrpcAdapter
|
61
|
+
from .node_auth_client_interceptor import NodeAuthClientInterceptor
|
62
62
|
|
63
63
|
|
64
64
|
@contextmanager
|
@@ -138,8 +138,9 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
138
138
|
|
139
139
|
# Always configure auth interceptor, with either user-provided or generated keys
|
140
140
|
interceptors: Sequence[grpc.UnaryUnaryClientInterceptor] = [
|
141
|
-
|
141
|
+
NodeAuthClientInterceptor(*authentication_keys),
|
142
142
|
]
|
143
|
+
node_pk = public_key_to_bytes(authentication_keys[1])
|
143
144
|
channel = create_channel(
|
144
145
|
server_address=server_address,
|
145
146
|
insecure=insecure,
|
@@ -199,7 +200,8 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
199
200
|
"""Set create_node."""
|
200
201
|
# Call FleetAPI
|
201
202
|
create_node_request = CreateNodeRequest(
|
202
|
-
|
203
|
+
public_key=node_pk,
|
204
|
+
heartbeat_interval=HEARTBEAT_DEFAULT_INTERVAL,
|
203
205
|
)
|
204
206
|
create_node_response = stub.CreateNode(request=create_node_request)
|
205
207
|
|
@@ -26,8 +26,8 @@ from flwr.common.constant import PUBLIC_KEY_HEADER, SIGNATURE_HEADER, TIMESTAMP_
|
|
26
26
|
from flwr.supercore.primitives.asymmetric import public_key_to_bytes, sign_message
|
27
27
|
|
28
28
|
|
29
|
-
class
|
30
|
-
"""Client interceptor for
|
29
|
+
class NodeAuthClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type: ignore
|
30
|
+
"""Client interceptor for node authentication."""
|
31
31
|
|
32
32
|
def __init__(
|
33
33
|
self,
|
@@ -15,6 +15,7 @@
|
|
15
15
|
"""Contextmanager for a REST request-response channel to the Flower server."""
|
16
16
|
|
17
17
|
|
18
|
+
import secrets
|
18
19
|
from collections.abc import Iterator
|
19
20
|
from contextlib import contextmanager
|
20
21
|
from logging import ERROR, WARN
|
@@ -292,7 +293,12 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
292
293
|
|
293
294
|
def create_node() -> Optional[int]:
|
294
295
|
"""Set create_node."""
|
295
|
-
req = CreateNodeRequest(
|
296
|
+
req = CreateNodeRequest(
|
297
|
+
# REST does not support node authentication;
|
298
|
+
# random bytes are used instead
|
299
|
+
public_key=secrets.token_bytes(32),
|
300
|
+
heartbeat_interval=HEARTBEAT_DEFAULT_INTERVAL,
|
301
|
+
)
|
296
302
|
|
297
303
|
# Send the request
|
298
304
|
res = _request(req, CreateNodeResponse, PATH_CREATE_NODE)
|
flwr/common/constant.py
CHANGED
@@ -157,6 +157,9 @@ RUN_ID_NOT_FOUND_MESSAGE = "Run ID not found"
|
|
157
157
|
NO_ACCOUNT_AUTH_MESSAGE = "ControlServicer initialized without account authentication"
|
158
158
|
NO_ARTIFACT_PROVIDER_MESSAGE = "ControlServicer initialized without artifact provider"
|
159
159
|
PULL_UNFINISHED_RUN_MESSAGE = "Cannot pull artifacts for an unfinished run"
|
160
|
+
PUBLIC_KEY_ALREADY_IN_USE_MESSAGE = "Public key already in use"
|
161
|
+
PUBLIC_KEY_NOT_VALID = "The provided public key is not valid"
|
162
|
+
NODE_NOT_FOUND_MESSAGE = "Node ID not found for account"
|
160
163
|
|
161
164
|
|
162
165
|
class MessageType:
|
@@ -256,6 +259,16 @@ class AuthnType:
|
|
256
259
|
raise TypeError(f"{cls.__name__} cannot be instantiated.")
|
257
260
|
|
258
261
|
|
262
|
+
class AuthzType:
|
263
|
+
"""Account authorization types."""
|
264
|
+
|
265
|
+
NOOP = "noop"
|
266
|
+
|
267
|
+
def __new__(cls) -> AuthzType:
|
268
|
+
"""Prevent instantiation."""
|
269
|
+
raise TypeError(f"{cls.__name__} cannot be instantiated.")
|
270
|
+
|
271
|
+
|
259
272
|
class EventLogWriterType:
|
260
273
|
"""Event log writer types."""
|
261
274
|
|
flwr/proto/control_pb2.py
CHANGED
@@ -19,7 +19,7 @@ from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2
|
|
19
19
|
from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2
|
20
20
|
|
21
21
|
|
22
|
-
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x66lwr/proto/control.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x1b\x66lwr/proto/recorddict.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x15\x66lwr/proto/node.proto\"\xfa\x01\n\x0fStartRunRequest\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fab\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x12\x34\n\x12\x66\x65\x64\x65ration_options\x18\x03 \x01(\x0b\x32\x18.flwr.proto.ConfigRecord\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"2\n\x10StartRunResponse\x12\x13\n\x06run_id\x18\x01 \x01(\x04H\x00\x88\x01\x01\x42\t\n\x07_run_id\"<\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x17\n\x0f\x61\x66ter_timestamp\x18\x02 \x01(\x01\"B\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t\x12\x18\n\x10latest_timestamp\x18\x02 \x01(\x01\"1\n\x0fListRunsRequest\x12\x13\n\x06run_id\x18\x01 \x01(\x04H\x00\x88\x01\x01\x42\t\n\x07_run_id\"\x9d\x01\n\x10ListRunsResponse\x12;\n\x08run_dict\x18\x01 \x03(\x0b\x32).flwr.proto.ListRunsResponse.RunDictEntry\x12\x0b\n\x03now\x18\x02 \x01(\t\x1a?\n\x0cRunDictEntry\x12\x0b\n\x03key\x18\x01 \x01(\x04\x12\x1e\n\x05value\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run:\x02\x38\x01\"\x18\n\x16GetLoginDetailsRequest\"\x8b\x01\n\x17GetLoginDetailsResponse\x12\x12\n\nauthn_type\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x65vice_code\x18\x02 \x01(\t\x12!\n\x19verification_uri_complete\x18\x03 \x01(\t\x12\x12\n\nexpires_in\x18\x04 \x01(\x03\x12\x10\n\x08interval\x18\x05 \x01(\x03\"+\n\x14GetAuthTokensRequest\x12\x13\n\x0b\x64\x65vice_code\x18\x01 \x01(\t\"D\n\x15GetAuthTokensResponse\x12\x14\n\x0c\x61\x63\x63\x65ss_token\x18\x01 \x01(\t\x12\x15\n\rrefresh_token\x18\x02 \x01(\t\" \n\x0eStopRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"\"\n\x0fStopRunResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"&\n\x14PullArtifactsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"1\n\x15PullArtifactsResponse\x12\x10\n\x03url\x18\x01 \x01(\tH\x00\x88\x01\x01\x42\x06\n\x04_url\"*\n\x14\x43reateNodeCliRequest\x12\x12\n\npublic_key\x18\x01 \x01(\
|
22
|
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x66lwr/proto/control.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x1b\x66lwr/proto/recorddict.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x15\x66lwr/proto/node.proto\"\xfa\x01\n\x0fStartRunRequest\x12\x1c\n\x03\x66\x61\x62\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Fab\x12H\n\x0foverride_config\x18\x02 \x03(\x0b\x32/.flwr.proto.StartRunRequest.OverrideConfigEntry\x12\x34\n\x12\x66\x65\x64\x65ration_options\x18\x03 \x01(\x0b\x32\x18.flwr.proto.ConfigRecord\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"2\n\x10StartRunResponse\x12\x13\n\x06run_id\x18\x01 \x01(\x04H\x00\x88\x01\x01\x42\t\n\x07_run_id\"<\n\x11StreamLogsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x17\n\x0f\x61\x66ter_timestamp\x18\x02 \x01(\x01\"B\n\x12StreamLogsResponse\x12\x12\n\nlog_output\x18\x01 \x01(\t\x12\x18\n\x10latest_timestamp\x18\x02 \x01(\x01\"1\n\x0fListRunsRequest\x12\x13\n\x06run_id\x18\x01 \x01(\x04H\x00\x88\x01\x01\x42\t\n\x07_run_id\"\x9d\x01\n\x10ListRunsResponse\x12;\n\x08run_dict\x18\x01 \x03(\x0b\x32).flwr.proto.ListRunsResponse.RunDictEntry\x12\x0b\n\x03now\x18\x02 \x01(\t\x1a?\n\x0cRunDictEntry\x12\x0b\n\x03key\x18\x01 \x01(\x04\x12\x1e\n\x05value\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run:\x02\x38\x01\"\x18\n\x16GetLoginDetailsRequest\"\x8b\x01\n\x17GetLoginDetailsResponse\x12\x12\n\nauthn_type\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x65vice_code\x18\x02 \x01(\t\x12!\n\x19verification_uri_complete\x18\x03 \x01(\t\x12\x12\n\nexpires_in\x18\x04 \x01(\x03\x12\x10\n\x08interval\x18\x05 \x01(\x03\"+\n\x14GetAuthTokensRequest\x12\x13\n\x0b\x64\x65vice_code\x18\x01 \x01(\t\"D\n\x15GetAuthTokensResponse\x12\x14\n\x0c\x61\x63\x63\x65ss_token\x18\x01 \x01(\t\x12\x15\n\rrefresh_token\x18\x02 \x01(\t\" \n\x0eStopRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"\"\n\x0fStopRunResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\"&\n\x14PullArtifactsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"1\n\x15PullArtifactsResponse\x12\x10\n\x03url\x18\x01 \x01(\tH\x00\x88\x01\x01\x42\x06\n\x04_url\"*\n\x14\x43reateNodeCliRequest\x12\x12\n\npublic_key\x18\x01 \x01(\x0c\"9\n\x15\x43reateNodeCliResponse\x12\x14\n\x07node_id\x18\x01 \x01(\x04H\x00\x88\x01\x01\x42\n\n\x08_node_id\"\'\n\x14\x44\x65leteNodeCliRequest\x12\x0f\n\x07node_id\x18\x01 \x01(\x04\"\x17\n\x15\x44\x65leteNodeCliResponse\"\x15\n\x13ListNodesCliRequest\"M\n\x14ListNodesCliResponse\x12(\n\nnodes_info\x18\x01 \x03(\x0b\x32\x14.flwr.proto.NodeInfo\x12\x0b\n\x03now\x18\x02 \x01(\t2\xc5\x06\n\x07\x43ontrol\x12G\n\x08StartRun\x12\x1b.flwr.proto.StartRunRequest\x1a\x1c.flwr.proto.StartRunResponse\"\x00\x12\x44\n\x07StopRun\x12\x1a.flwr.proto.StopRunRequest\x1a\x1b.flwr.proto.StopRunResponse\"\x00\x12O\n\nStreamLogs\x12\x1d.flwr.proto.StreamLogsRequest\x1a\x1e.flwr.proto.StreamLogsResponse\"\x00\x30\x01\x12G\n\x08ListRuns\x12\x1b.flwr.proto.ListRunsRequest\x1a\x1c.flwr.proto.ListRunsResponse\"\x00\x12\\\n\x0fGetLoginDetails\x12\".flwr.proto.GetLoginDetailsRequest\x1a#.flwr.proto.GetLoginDetailsResponse\"\x00\x12V\n\rGetAuthTokens\x12 .flwr.proto.GetAuthTokensRequest\x1a!.flwr.proto.GetAuthTokensResponse\"\x00\x12V\n\rPullArtifacts\x12 .flwr.proto.PullArtifactsRequest\x1a!.flwr.proto.PullArtifactsResponse\"\x00\x12V\n\rCreateNodeCli\x12 .flwr.proto.CreateNodeCliRequest\x1a!.flwr.proto.CreateNodeCliResponse\"\x00\x12V\n\rDeleteNodeCli\x12 .flwr.proto.DeleteNodeCliRequest\x1a!.flwr.proto.DeleteNodeCliResponse\"\x00\x12S\n\x0cListNodesCli\x12\x1f.flwr.proto.ListNodesCliRequest\x1a .flwr.proto.ListNodesCliResponse\"\x00\x62\x06proto3')
|
23
23
|
|
24
24
|
_globals = globals()
|
25
25
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
flwr/proto/control_pb2.pyi
CHANGED
@@ -239,10 +239,10 @@ global___PullArtifactsResponse = PullArtifactsResponse
|
|
239
239
|
class CreateNodeCliRequest(google.protobuf.message.Message):
|
240
240
|
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
241
241
|
PUBLIC_KEY_FIELD_NUMBER: builtins.int
|
242
|
-
public_key:
|
242
|
+
public_key: builtins.bytes
|
243
243
|
def __init__(self,
|
244
244
|
*,
|
245
|
-
public_key:
|
245
|
+
public_key: builtins.bytes = ...,
|
246
246
|
) -> None: ...
|
247
247
|
def ClearField(self, field_name: typing_extensions.Literal["public_key",b"public_key"]) -> None: ...
|
248
248
|
global___CreateNodeCliRequest = CreateNodeCliRequest
|
flwr/proto/fleet_pb2.py
CHANGED
@@ -19,7 +19,7 @@ from flwr.proto import fab_pb2 as flwr_dot_proto_dot_fab__pb2
|
|
19
19
|
from flwr.proto import message_pb2 as flwr_dot_proto_dot_message__pb2
|
20
20
|
|
21
21
|
|
22
|
-
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/fleet.proto\x12\nflwr.proto\x1a\x1a\x66lwr/proto/heartbeat.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x18\x66lwr/proto/message.proto\"
|
22
|
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66lwr/proto/fleet.proto\x12\nflwr.proto\x1a\x1a\x66lwr/proto/heartbeat.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x18\x66lwr/proto/message.proto\"C\n\x11\x43reateNodeRequest\x12\x12\n\npublic_key\x18\x01 \x01(\x0c\x12\x1a\n\x12heartbeat_interval\x18\x02 \x01(\x01\"4\n\x12\x43reateNodeResponse\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"3\n\x11\x44\x65leteNodeRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\"\x14\n\x12\x44\x65leteNodeResponse\"J\n\x13PullMessagesRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x13\n\x0bmessage_ids\x18\x02 \x03(\t\"\xa2\x01\n\x14PullMessagesResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12*\n\rmessages_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.Message\x12\x34\n\x14message_object_trees\x18\x03 \x03(\x0b\x32\x16.flwr.proto.ObjectTree\"\x97\x01\n\x13PushMessagesRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12*\n\rmessages_list\x18\x02 \x03(\x0b\x32\x13.flwr.proto.Message\x12\x34\n\x14message_object_trees\x18\x03 \x03(\x0b\x32\x16.flwr.proto.ObjectTree\"\xc9\x01\n\x14PushMessagesResponse\x12(\n\treconnect\x18\x01 \x01(\x0b\x32\x15.flwr.proto.Reconnect\x12>\n\x07results\x18\x02 \x03(\x0b\x32-.flwr.proto.PushMessagesResponse.ResultsEntry\x12\x17\n\x0fobjects_to_push\x18\x03 \x03(\t\x1a.\n\x0cResultsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\r:\x02\x38\x01\"\x1e\n\tReconnect\x12\x11\n\treconnect\x18\x01 \x01(\x04\x32\xca\x06\n\x05\x46leet\x12M\n\nCreateNode\x12\x1d.flwr.proto.CreateNodeRequest\x1a\x1e.flwr.proto.CreateNodeResponse\"\x00\x12M\n\nDeleteNode\x12\x1d.flwr.proto.DeleteNodeRequest\x1a\x1e.flwr.proto.DeleteNodeResponse\"\x00\x12\x62\n\x11SendNodeHeartbeat\x12$.flwr.proto.SendNodeHeartbeatRequest\x1a%.flwr.proto.SendNodeHeartbeatResponse\"\x00\x12S\n\x0cPullMessages\x12\x1f.flwr.proto.PullMessagesRequest\x1a .flwr.proto.PullMessagesResponse\"\x00\x12S\n\x0cPushMessages\x12\x1f.flwr.proto.PushMessagesRequest\x1a .flwr.proto.PushMessagesResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x12M\n\nPushObject\x12\x1d.flwr.proto.PushObjectRequest\x1a\x1e.flwr.proto.PushObjectResponse\"\x00\x12M\n\nPullObject\x12\x1d.flwr.proto.PullObjectRequest\x1a\x1e.flwr.proto.PullObjectResponse\"\x00\x12q\n\x16\x43onfirmMessageReceived\x12).flwr.proto.ConfirmMessageReceivedRequest\x1a*.flwr.proto.ConfirmMessageReceivedResponse\"\x00\x62\x06proto3')
|
23
23
|
|
24
24
|
_globals = globals()
|
25
25
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
@@ -29,25 +29,25 @@ if _descriptor._USE_C_DESCRIPTORS == False:
|
|
29
29
|
_globals['_PUSHMESSAGESRESPONSE_RESULTSENTRY']._options = None
|
30
30
|
_globals['_PUSHMESSAGESRESPONSE_RESULTSENTRY']._serialized_options = b'8\001'
|
31
31
|
_globals['_CREATENODEREQUEST']._serialized_start=159
|
32
|
-
_globals['_CREATENODEREQUEST']._serialized_end=
|
33
|
-
_globals['_CREATENODERESPONSE']._serialized_start=
|
34
|
-
_globals['_CREATENODERESPONSE']._serialized_end=
|
35
|
-
_globals['_DELETENODEREQUEST']._serialized_start=
|
36
|
-
_globals['_DELETENODEREQUEST']._serialized_end=
|
37
|
-
_globals['_DELETENODERESPONSE']._serialized_start=
|
38
|
-
_globals['_DELETENODERESPONSE']._serialized_end=
|
39
|
-
_globals['_PULLMESSAGESREQUEST']._serialized_start=
|
40
|
-
_globals['_PULLMESSAGESREQUEST']._serialized_end=
|
41
|
-
_globals['_PULLMESSAGESRESPONSE']._serialized_start=
|
42
|
-
_globals['_PULLMESSAGESRESPONSE']._serialized_end=
|
43
|
-
_globals['_PUSHMESSAGESREQUEST']._serialized_start=
|
44
|
-
_globals['_PUSHMESSAGESREQUEST']._serialized_end=
|
45
|
-
_globals['_PUSHMESSAGESRESPONSE']._serialized_start=
|
46
|
-
_globals['_PUSHMESSAGESRESPONSE']._serialized_end=
|
47
|
-
_globals['_PUSHMESSAGESRESPONSE_RESULTSENTRY']._serialized_start=
|
48
|
-
_globals['_PUSHMESSAGESRESPONSE_RESULTSENTRY']._serialized_end=
|
49
|
-
_globals['_RECONNECT']._serialized_start=
|
50
|
-
_globals['_RECONNECT']._serialized_end=
|
51
|
-
_globals['_FLEET']._serialized_start=
|
52
|
-
_globals['_FLEET']._serialized_end=
|
32
|
+
_globals['_CREATENODEREQUEST']._serialized_end=226
|
33
|
+
_globals['_CREATENODERESPONSE']._serialized_start=228
|
34
|
+
_globals['_CREATENODERESPONSE']._serialized_end=280
|
35
|
+
_globals['_DELETENODEREQUEST']._serialized_start=282
|
36
|
+
_globals['_DELETENODEREQUEST']._serialized_end=333
|
37
|
+
_globals['_DELETENODERESPONSE']._serialized_start=335
|
38
|
+
_globals['_DELETENODERESPONSE']._serialized_end=355
|
39
|
+
_globals['_PULLMESSAGESREQUEST']._serialized_start=357
|
40
|
+
_globals['_PULLMESSAGESREQUEST']._serialized_end=431
|
41
|
+
_globals['_PULLMESSAGESRESPONSE']._serialized_start=434
|
42
|
+
_globals['_PULLMESSAGESRESPONSE']._serialized_end=596
|
43
|
+
_globals['_PUSHMESSAGESREQUEST']._serialized_start=599
|
44
|
+
_globals['_PUSHMESSAGESREQUEST']._serialized_end=750
|
45
|
+
_globals['_PUSHMESSAGESRESPONSE']._serialized_start=753
|
46
|
+
_globals['_PUSHMESSAGESRESPONSE']._serialized_end=954
|
47
|
+
_globals['_PUSHMESSAGESRESPONSE_RESULTSENTRY']._serialized_start=908
|
48
|
+
_globals['_PUSHMESSAGESRESPONSE_RESULTSENTRY']._serialized_end=954
|
49
|
+
_globals['_RECONNECT']._serialized_start=956
|
50
|
+
_globals['_RECONNECT']._serialized_end=986
|
51
|
+
_globals['_FLEET']._serialized_start=989
|
52
|
+
_globals['_FLEET']._serialized_end=1831
|
53
53
|
# @@protoc_insertion_point(module_scope)
|
flwr/proto/fleet_pb2.pyi
CHANGED
@@ -16,13 +16,16 @@ DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
|
|
16
16
|
class CreateNodeRequest(google.protobuf.message.Message):
|
17
17
|
"""CreateNode messages"""
|
18
18
|
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
19
|
+
PUBLIC_KEY_FIELD_NUMBER: builtins.int
|
19
20
|
HEARTBEAT_INTERVAL_FIELD_NUMBER: builtins.int
|
21
|
+
public_key: builtins.bytes
|
20
22
|
heartbeat_interval: builtins.float
|
21
23
|
def __init__(self,
|
22
24
|
*,
|
25
|
+
public_key: builtins.bytes = ...,
|
23
26
|
heartbeat_interval: builtins.float = ...,
|
24
27
|
) -> None: ...
|
25
|
-
def ClearField(self, field_name: typing_extensions.Literal["heartbeat_interval",b"heartbeat_interval"]) -> None: ...
|
28
|
+
def ClearField(self, field_name: typing_extensions.Literal["heartbeat_interval",b"heartbeat_interval","public_key",b"public_key"]) -> None: ...
|
26
29
|
global___CreateNodeRequest = CreateNodeRequest
|
27
30
|
|
28
31
|
class CreateNodeResponse(google.protobuf.message.Message):
|
flwr/proto/node_pb2.py
CHANGED
@@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default()
|
|
14
14
|
|
15
15
|
|
16
16
|
|
17
|
-
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/node.proto\x12\nflwr.proto\"\x17\n\x04Node\x12\x0f\n\x07node_id\x18\x01 \x01(\x04\"\
|
17
|
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/node.proto\x12\nflwr.proto\"\x17\n\x04Node\x12\x0f\n\x07node_id\x18\x01 \x01(\x04\"\xe4\x01\n\x08NodeInfo\x12\x0f\n\x07node_id\x18\x01 \x01(\x04\x12\x11\n\towner_aid\x18\x02 \x01(\t\x12\x0e\n\x06status\x18\x03 \x01(\t\x12\x12\n\ncreated_at\x18\x04 \x01(\t\x12\x19\n\x11last_activated_at\x18\x05 \x01(\t\x12\x1b\n\x13last_deactivated_at\x18\x06 \x01(\t\x12\x12\n\ndeleted_at\x18\x07 \x01(\t\x12\x14\n\x0conline_until\x18\x08 \x01(\x01\x12\x1a\n\x12heartbeat_interval\x18\t \x01(\x01\x12\x12\n\npublic_key\x18\n \x01(\x0c\x62\x06proto3')
|
18
18
|
|
19
19
|
_globals = globals()
|
20
20
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
@@ -24,5 +24,5 @@ if _descriptor._USE_C_DESCRIPTORS == False:
|
|
24
24
|
_globals['_NODE']._serialized_start=37
|
25
25
|
_globals['_NODE']._serialized_end=60
|
26
26
|
_globals['_NODEINFO']._serialized_start=63
|
27
|
-
_globals['_NODEINFO']._serialized_end=
|
27
|
+
_globals['_NODEINFO']._serialized_end=291
|
28
28
|
# @@protoc_insertion_point(module_scope)
|
flwr/proto/node_pb2.pyi
CHANGED
@@ -32,6 +32,7 @@ class NodeInfo(google.protobuf.message.Message):
|
|
32
32
|
DELETED_AT_FIELD_NUMBER: builtins.int
|
33
33
|
ONLINE_UNTIL_FIELD_NUMBER: builtins.int
|
34
34
|
HEARTBEAT_INTERVAL_FIELD_NUMBER: builtins.int
|
35
|
+
PUBLIC_KEY_FIELD_NUMBER: builtins.int
|
35
36
|
node_id: builtins.int
|
36
37
|
owner_aid: typing.Text
|
37
38
|
status: typing.Text
|
@@ -41,6 +42,7 @@ class NodeInfo(google.protobuf.message.Message):
|
|
41
42
|
deleted_at: typing.Text
|
42
43
|
online_until: builtins.float
|
43
44
|
heartbeat_interval: builtins.float
|
45
|
+
public_key: builtins.bytes
|
44
46
|
def __init__(self,
|
45
47
|
*,
|
46
48
|
node_id: builtins.int = ...,
|
@@ -52,6 +54,7 @@ class NodeInfo(google.protobuf.message.Message):
|
|
52
54
|
deleted_at: typing.Text = ...,
|
53
55
|
online_until: builtins.float = ...,
|
54
56
|
heartbeat_interval: builtins.float = ...,
|
57
|
+
public_key: builtins.bytes = ...,
|
55
58
|
) -> None: ...
|
56
|
-
def ClearField(self, field_name: typing_extensions.Literal["created_at",b"created_at","deleted_at",b"deleted_at","heartbeat_interval",b"heartbeat_interval","last_activated_at",b"last_activated_at","last_deactivated_at",b"last_deactivated_at","node_id",b"node_id","online_until",b"online_until","owner_aid",b"owner_aid","status",b"status"]) -> None: ...
|
59
|
+
def ClearField(self, field_name: typing_extensions.Literal["created_at",b"created_at","deleted_at",b"deleted_at","heartbeat_interval",b"heartbeat_interval","last_activated_at",b"last_activated_at","last_deactivated_at",b"last_deactivated_at","node_id",b"node_id","online_until",b"online_until","owner_aid",b"owner_aid","public_key",b"public_key","status",b"status"]) -> None: ...
|
57
60
|
global___NodeInfo = NodeInfo
|
flwr/server/app.py
CHANGED
@@ -26,7 +26,7 @@ from collections.abc import Sequence
|
|
26
26
|
from logging import DEBUG, INFO, WARN
|
27
27
|
from pathlib import Path
|
28
28
|
from time import sleep
|
29
|
-
from typing import
|
29
|
+
from typing import Callable, Optional, TypeVar, cast
|
30
30
|
|
31
31
|
import grpc
|
32
32
|
import yaml
|
@@ -52,6 +52,8 @@ from flwr.common.constant import (
|
|
52
52
|
TRANSPORT_TYPE_GRPC_ADAPTER,
|
53
53
|
TRANSPORT_TYPE_GRPC_RERE,
|
54
54
|
TRANSPORT_TYPE_REST,
|
55
|
+
AuthnType,
|
56
|
+
AuthzType,
|
55
57
|
EventLogWriterType,
|
56
58
|
ExecPluginType,
|
57
59
|
)
|
@@ -69,12 +71,19 @@ from flwr.supercore.grpc_health import add_args_health, run_health_server_grpc_n
|
|
69
71
|
from flwr.supercore.object_store import ObjectStoreFactory
|
70
72
|
from flwr.supercore.primitives.asymmetric import public_key_to_bytes
|
71
73
|
from flwr.superlink.artifact_provider import ArtifactProvider
|
72
|
-
from flwr.superlink.auth_plugin import
|
74
|
+
from flwr.superlink.auth_plugin import (
|
75
|
+
ControlAuthnPlugin,
|
76
|
+
ControlAuthzPlugin,
|
77
|
+
get_control_authn_plugins,
|
78
|
+
get_control_authz_plugins,
|
79
|
+
)
|
73
80
|
from flwr.superlink.servicer.control import run_control_api_grpc
|
74
81
|
|
75
82
|
from .superlink.fleet.grpc_adapter.grpc_adapter_servicer import GrpcAdapterServicer
|
76
83
|
from .superlink.fleet.grpc_rere.fleet_servicer import FleetServicer
|
77
|
-
from .superlink.fleet.grpc_rere.
|
84
|
+
from .superlink.fleet.grpc_rere.node_auth_server_interceptor import (
|
85
|
+
NodeAuthServerInterceptor,
|
86
|
+
)
|
78
87
|
from .superlink.linkstate import LinkStateFactory
|
79
88
|
from .superlink.serverappio.serverappio_grpc import run_serverappio_api_grpc
|
80
89
|
from .superlink.simulation.simulationio_grpc import run_simulationio_api_grpc
|
@@ -87,8 +96,6 @@ P = TypeVar("P", ControlAuthnPlugin, ControlAuthzPlugin)
|
|
87
96
|
try:
|
88
97
|
from flwr.ee import (
|
89
98
|
add_ee_args_superlink,
|
90
|
-
get_control_authn_plugins,
|
91
|
-
get_control_authz_plugins,
|
92
99
|
get_control_event_log_writer_plugins,
|
93
100
|
get_ee_artifact_provider,
|
94
101
|
get_fleet_event_log_writer_plugins,
|
@@ -99,14 +106,6 @@ except ImportError:
|
|
99
106
|
def add_ee_args_superlink(parser: argparse.ArgumentParser) -> None:
|
100
107
|
"""Add EE-specific arguments to the parser."""
|
101
108
|
|
102
|
-
def get_control_authn_plugins() -> dict[str, type[ControlAuthnPlugin]]:
|
103
|
-
"""Return all Control API authentication plugins."""
|
104
|
-
raise NotImplementedError("No authentication plugins are currently supported.")
|
105
|
-
|
106
|
-
def get_control_authz_plugins() -> dict[str, type[ControlAuthzPlugin]]:
|
107
|
-
"""Return all Control API authorization plugins."""
|
108
|
-
raise NotImplementedError("No authorization plugins are currently supported.")
|
109
|
-
|
110
109
|
def get_control_event_log_writer_plugins() -> dict[str, type[EventLogWriterPlugin]]:
|
111
110
|
"""Return all Control API event log writer plugins."""
|
112
111
|
raise NotImplementedError(
|
@@ -202,10 +201,9 @@ def run_superlink() -> None:
|
|
202
201
|
"future release. Please use `--account-auth-config` instead.",
|
203
202
|
)
|
204
203
|
args.account_auth_config = cfg_path
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
)
|
204
|
+
cfg_path = getattr(args, "account_auth_config", None)
|
205
|
+
authn_plugin, authz_plugin = _load_control_auth_plugins(cfg_path, verify_tls_cert)
|
206
|
+
if cfg_path is not None:
|
209
207
|
# Enable event logging if the args.enable_event_log is True
|
210
208
|
if args.enable_event_log:
|
211
209
|
event_log_plugin = _try_obtain_control_event_log_writer_plugin()
|
@@ -326,7 +324,7 @@ def run_superlink() -> None:
|
|
326
324
|
else:
|
327
325
|
log(DEBUG, "Automatic node authentication enabled")
|
328
326
|
|
329
|
-
interceptors = [
|
327
|
+
interceptors = [NodeAuthServerInterceptor(state_factory, auto_auth)]
|
330
328
|
if getattr(args, "enable_event_log", None):
|
331
329
|
fleet_log_plugin = _try_obtain_fleet_event_log_writer_plugin()
|
332
330
|
if fleet_log_plugin is not None:
|
@@ -447,13 +445,21 @@ def _try_load_public_keys_node_authentication(
|
|
447
445
|
return node_public_keys
|
448
446
|
|
449
447
|
|
450
|
-
def
|
451
|
-
config_path:
|
448
|
+
def _load_control_auth_plugins(
|
449
|
+
config_path: Optional[str], verify_tls_cert: bool
|
452
450
|
) -> tuple[ControlAuthnPlugin, ControlAuthzPlugin]:
|
453
451
|
"""Obtain Control API authentication and authorization plugins."""
|
452
|
+
# Load NoOp plugins if no config path is provided
|
453
|
+
if config_path is None:
|
454
|
+
config_path = ""
|
455
|
+
config = {
|
456
|
+
"authentication": {AUTHN_TYPE_YAML_KEY: AuthnType.NOOP},
|
457
|
+
"authorization": {AUTHZ_TYPE_YAML_KEY: AuthzType.NOOP},
|
458
|
+
}
|
454
459
|
# Load YAML file
|
455
|
-
|
456
|
-
|
460
|
+
else:
|
461
|
+
with Path(config_path).open("r", encoding="utf-8") as file:
|
462
|
+
config = yaml.safe_load(file)
|
457
463
|
|
458
464
|
def _load_plugin(
|
459
465
|
section: str, yaml_key: str, loader: Callable[[], dict[str, type[P]]]
|
@@ -463,9 +469,7 @@ def _try_obtain_control_auth_plugins(
|
|
463
469
|
try:
|
464
470
|
plugins: dict[str, type[P]] = loader()
|
465
471
|
plugin_cls: type[P] = plugins[auth_plugin_name]
|
466
|
-
return plugin_cls(
|
467
|
-
account_auth_config_path=config_path, verify_tls_cert=verify_tls_cert
|
468
|
-
)
|
472
|
+
return plugin_cls(Path(cast(str, config_path)), verify_tls_cert)
|
469
473
|
except KeyError:
|
470
474
|
if auth_plugin_name:
|
471
475
|
sys.exit(
|
@@ -473,18 +477,15 @@ def _try_obtain_control_auth_plugins(
|
|
473
477
|
f"Please provide a valid {section} type in the configuration."
|
474
478
|
)
|
475
479
|
sys.exit(f"No {section} type is provided in the configuration.")
|
476
|
-
except NotImplementedError:
|
477
|
-
sys.exit(f"No {section} plugins are currently supported.")
|
478
480
|
|
479
|
-
# Warn deprecated
|
480
|
-
if
|
481
|
+
# Warn deprecated auth_type key
|
482
|
+
if authn_type := config["authentication"].pop("auth_type", None):
|
481
483
|
log(
|
482
484
|
WARN,
|
483
|
-
"The `
|
485
|
+
"The `auth_type` key in the authentication configuration is deprecated. "
|
484
486
|
"Use `%s` instead.",
|
485
487
|
AUTHN_TYPE_YAML_KEY,
|
486
488
|
)
|
487
|
-
authn_type = config["authentication"].pop("authn_type")
|
488
489
|
config["authentication"][AUTHN_TYPE_YAML_KEY] = authn_type
|
489
490
|
|
490
491
|
# Load authentication plugin
|
@@ -78,10 +78,14 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
78
78
|
request.heartbeat_interval,
|
79
79
|
)
|
80
80
|
log(DEBUG, "[Fleet.CreateNode] Request: %s", MessageToDict(request))
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
81
|
+
try:
|
82
|
+
response = message_handler.create_node(
|
83
|
+
request=request,
|
84
|
+
state=self.state_factory.state(),
|
85
|
+
)
|
86
|
+
except ValueError as e:
|
87
|
+
# Public key already in use
|
88
|
+
context.abort(grpc.StatusCode.FAILED_PRECONDITION, str(e))
|
85
89
|
log(INFO, "[Fleet.CreateNode] Created node_id=%s", response.node.node_id)
|
86
90
|
log(DEBUG, "[Fleet.CreateNode] Response: %s", MessageToDict(response))
|
87
91
|
return response
|
flwr/server/superlink/fleet/grpc_rere/{server_interceptor.py → node_auth_server_interceptor.py}
RENAMED
@@ -16,7 +16,7 @@
|
|
16
16
|
|
17
17
|
|
18
18
|
import datetime
|
19
|
-
from typing import Any, Callable,
|
19
|
+
from typing import Any, Callable, cast
|
20
20
|
|
21
21
|
import grpc
|
22
22
|
from google.protobuf.message import Message as GrpcMessage
|
@@ -29,10 +29,7 @@ from flwr.common.constant import (
|
|
29
29
|
TIMESTAMP_HEADER,
|
30
30
|
TIMESTAMP_TOLERANCE,
|
31
31
|
)
|
32
|
-
from flwr.proto.fleet_pb2 import
|
33
|
-
CreateNodeRequest,
|
34
|
-
CreateNodeResponse,
|
35
|
-
)
|
32
|
+
from flwr.proto.fleet_pb2 import CreateNodeRequest # pylint: disable=E0611
|
36
33
|
from flwr.server.superlink.linkstate import LinkStateFactory
|
37
34
|
from flwr.supercore.primitives.asymmetric import bytes_to_public_key, verify_signature
|
38
35
|
|
@@ -50,7 +47,7 @@ def _unary_unary_rpc_terminator(
|
|
50
47
|
return grpc.unary_unary_rpc_method_handler(terminate)
|
51
48
|
|
52
49
|
|
53
|
-
class
|
50
|
+
class NodeAuthServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
54
51
|
"""Server interceptor for node authentication.
|
55
52
|
|
56
53
|
Parameters
|
@@ -110,50 +107,34 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
110
107
|
if not MIN_TIMESTAMP_DIFF < time_diff.total_seconds() < MAX_TIMESTAMP_DIFF:
|
111
108
|
return _unary_unary_rpc_terminator("Invalid timestamp")
|
112
109
|
|
113
|
-
# Continue the RPC call
|
114
|
-
expected_node_id = state.get_node_id(node_pk_bytes)
|
115
|
-
if not handler_call_details.method.endswith("CreateNode"):
|
116
|
-
# All calls, except for `CreateNode`, must provide a public key that is
|
117
|
-
# already mapped to a `node_id` (in `LinkState`)
|
118
|
-
if expected_node_id is None:
|
119
|
-
return _unary_unary_rpc_terminator("Invalid node ID")
|
120
|
-
# One of the method handlers in
|
110
|
+
# Continue the RPC call: One of the method handlers in
|
121
111
|
# `flwr.server.superlink.fleet.grpc_rere.fleet_server.FleetServicer`
|
122
112
|
method_handler: grpc.RpcMethodHandler = continuation(handler_call_details)
|
123
|
-
return self._wrap_method_handler(
|
124
|
-
method_handler, expected_node_id, node_pk_bytes
|
125
|
-
)
|
113
|
+
return self._wrap_method_handler(method_handler, node_pk_bytes)
|
126
114
|
|
127
115
|
def _wrap_method_handler(
|
128
116
|
self,
|
129
117
|
method_handler: grpc.RpcMethodHandler,
|
130
|
-
|
131
|
-
node_public_key: bytes,
|
118
|
+
expected_public_key: bytes,
|
132
119
|
) -> grpc.RpcMethodHandler:
|
133
120
|
def _generic_method_handler(
|
134
121
|
request: GrpcMessage,
|
135
122
|
context: grpc.ServicerContext,
|
136
123
|
) -> GrpcMessage:
|
137
|
-
#
|
138
|
-
if
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
124
|
+
# Retrieve the public key
|
125
|
+
if isinstance(request, CreateNodeRequest):
|
126
|
+
actual_public_key = request.public_key
|
127
|
+
else:
|
128
|
+
# Note: This function runs in a different thread
|
129
|
+
# than the `intercept_service` function.
|
130
|
+
actual_public_key = self.state_factory.state().get_node_public_key(
|
131
|
+
request.node.node_id # type: ignore
|
132
|
+
)
|
133
|
+
# Verify the public key
|
134
|
+
if actual_public_key != expected_public_key:
|
135
|
+
context.abort(grpc.StatusCode.UNAUTHENTICATED, "Invalid node ID")
|
144
136
|
|
145
137
|
response: GrpcMessage = method_handler.unary_unary(request, context)
|
146
|
-
|
147
|
-
# Set the public key after a successful CreateNode request
|
148
|
-
if isinstance(response, CreateNodeResponse):
|
149
|
-
state = self.state_factory.state()
|
150
|
-
try:
|
151
|
-
state.set_node_public_key(response.node.node_id, node_public_key)
|
152
|
-
except ValueError as e:
|
153
|
-
# Remove newly created node if setting the public key fails
|
154
|
-
state.delete_node(response.node.node_id)
|
155
|
-
context.abort(grpc.StatusCode.UNAUTHENTICATED, str(e))
|
156
|
-
|
157
138
|
return response
|
158
139
|
|
159
140
|
return grpc.unary_unary_rpc_method_handler(
|
@@ -18,7 +18,7 @@ from logging import ERROR
|
|
18
18
|
from typing import Optional
|
19
19
|
|
20
20
|
from flwr.common import Message, log
|
21
|
-
from flwr.common.constant import Status
|
21
|
+
from flwr.common.constant import NOOP_FLWR_AID, Status
|
22
22
|
from flwr.common.inflatable import UnexpectedObjectContentError
|
23
23
|
from flwr.common.serde import (
|
24
24
|
fab_to_proto,
|
@@ -70,7 +70,9 @@ def create_node(
|
|
70
70
|
) -> CreateNodeResponse:
|
71
71
|
"""."""
|
72
72
|
# Create node
|
73
|
-
node_id = state.create_node(
|
73
|
+
node_id = state.create_node(
|
74
|
+
NOOP_FLWR_AID, request.public_key, request.heartbeat_interval
|
75
|
+
)
|
74
76
|
return CreateNodeResponse(node=Node(node_id=node_id))
|
75
77
|
|
76
78
|
|
@@ -81,7 +83,7 @@ def delete_node(request: DeleteNodeRequest, state: LinkState) -> DeleteNodeRespo
|
|
81
83
|
return DeleteNodeResponse()
|
82
84
|
|
83
85
|
# Update state
|
84
|
-
state.delete_node(node_id=request.node.node_id)
|
86
|
+
state.delete_node(NOOP_FLWR_AID, node_id=request.node.node_id)
|
85
87
|
return DeleteNodeResponse()
|
86
88
|
|
87
89
|
|
@@ -16,6 +16,7 @@
|
|
16
16
|
|
17
17
|
|
18
18
|
import json
|
19
|
+
import secrets
|
19
20
|
import threading
|
20
21
|
import time
|
21
22
|
import traceback
|
@@ -33,6 +34,7 @@ from flwr.clientapp.utils import get_load_client_app_fn
|
|
33
34
|
from flwr.common import Message
|
34
35
|
from flwr.common.constant import (
|
35
36
|
HEARTBEAT_MAX_INTERVAL,
|
37
|
+
NOOP_FLWR_AID,
|
36
38
|
NUM_PARTITIONS_KEY,
|
37
39
|
PARTITION_ID_KEY,
|
38
40
|
ErrorCode,
|
@@ -53,7 +55,14 @@ def _register_nodes(
|
|
53
55
|
nodes_mapping: NodeToPartitionMapping = {}
|
54
56
|
state = state_factory.state()
|
55
57
|
for i in range(num_nodes):
|
56
|
-
node_id = state.create_node(
|
58
|
+
node_id = state.create_node(
|
59
|
+
# No node authentication in simulation;
|
60
|
+
# use NOOP_FLWR_AID as owner_aid and
|
61
|
+
# use random bytes as public key
|
62
|
+
NOOP_FLWR_AID,
|
63
|
+
secrets.token_bytes(32),
|
64
|
+
heartbeat_interval=HEARTBEAT_MAX_INTERVAL,
|
65
|
+
)
|
57
66
|
nodes_mapping[node_id] = i
|
58
67
|
log(DEBUG, "Registered %i nodes", len(nodes_mapping))
|
59
68
|
return nodes_mapping
|