flwr 1.14.0__py3-none-any.whl → 1.15.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/auth_plugin/__init__.py +31 -0
- flwr/cli/auth_plugin/oidc_cli_plugin.py +150 -0
- flwr/cli/cli_user_auth_interceptor.py +6 -2
- flwr/cli/config_utils.py +24 -147
- flwr/cli/constant.py +27 -0
- flwr/cli/install.py +1 -1
- flwr/cli/log.py +18 -3
- flwr/cli/login/login.py +43 -8
- flwr/cli/ls.py +14 -5
- flwr/cli/new/templates/app/README.md.tpl +3 -2
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +3 -3
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
- flwr/cli/run/run.py +21 -11
- flwr/cli/stop.py +13 -4
- flwr/cli/utils.py +54 -40
- flwr/client/app.py +36 -48
- flwr/client/clientapp/app.py +19 -25
- flwr/client/clientapp/utils.py +1 -1
- flwr/client/grpc_client/connection.py +1 -12
- flwr/client/grpc_rere_client/client_interceptor.py +19 -119
- flwr/client/grpc_rere_client/connection.py +46 -36
- flwr/client/grpc_rere_client/grpc_adapter.py +12 -12
- flwr/client/message_handler/task_handler.py +0 -17
- flwr/client/rest_client/connection.py +34 -26
- flwr/client/supernode/app.py +18 -72
- flwr/common/args.py +25 -47
- flwr/common/auth_plugin/auth_plugin.py +34 -23
- flwr/common/config.py +166 -16
- flwr/common/constant.py +24 -9
- flwr/common/differential_privacy.py +2 -1
- flwr/common/exit/__init__.py +24 -0
- flwr/common/exit/exit.py +99 -0
- flwr/common/exit/exit_code.py +93 -0
- flwr/common/exit_handlers.py +32 -30
- flwr/common/grpc.py +167 -4
- flwr/common/logger.py +26 -7
- flwr/common/object_ref.py +0 -14
- flwr/common/record/recordset.py +1 -1
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +45 -0
- flwr/common/serde.py +6 -4
- flwr/common/typing.py +20 -0
- flwr/proto/clientappio_pb2.py +1 -1
- flwr/proto/error_pb2.py +1 -1
- flwr/proto/exec_pb2.py +13 -25
- flwr/proto/exec_pb2.pyi +27 -54
- flwr/proto/fab_pb2.py +1 -1
- flwr/proto/fleet_pb2.py +31 -31
- flwr/proto/fleet_pb2.pyi +23 -23
- flwr/proto/fleet_pb2_grpc.py +30 -30
- flwr/proto/fleet_pb2_grpc.pyi +20 -20
- flwr/proto/grpcadapter_pb2.py +1 -1
- flwr/proto/log_pb2.py +1 -1
- flwr/proto/message_pb2.py +1 -1
- flwr/proto/node_pb2.py +3 -3
- flwr/proto/node_pb2.pyi +1 -4
- flwr/proto/recordset_pb2.py +1 -1
- flwr/proto/run_pb2.py +1 -1
- flwr/proto/serverappio_pb2.py +24 -25
- flwr/proto/serverappio_pb2.pyi +26 -32
- flwr/proto/serverappio_pb2_grpc.py +28 -28
- flwr/proto/serverappio_pb2_grpc.pyi +16 -16
- flwr/proto/simulationio_pb2.py +1 -1
- flwr/proto/task_pb2.py +1 -1
- flwr/proto/transport_pb2.py +1 -1
- flwr/server/app.py +116 -128
- flwr/server/compat/app_utils.py +0 -1
- flwr/server/compat/driver_client_proxy.py +1 -2
- flwr/server/driver/grpc_driver.py +32 -27
- flwr/server/driver/inmemory_driver.py +2 -1
- flwr/server/serverapp/app.py +12 -10
- flwr/server/superlink/driver/serverappio_grpc.py +1 -1
- flwr/server/superlink/driver/serverappio_servicer.py +74 -48
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +20 -88
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +2 -165
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +25 -24
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +110 -168
- flwr/server/superlink/fleet/message_handler/message_handler.py +37 -24
- flwr/server/superlink/fleet/rest_rere/rest_api.py +16 -18
- flwr/server/superlink/fleet/vce/vce_api.py +2 -2
- flwr/server/superlink/linkstate/in_memory_linkstate.py +45 -75
- flwr/server/superlink/linkstate/linkstate.py +17 -38
- flwr/server/superlink/linkstate/sqlite_linkstate.py +81 -145
- flwr/server/superlink/linkstate/utils.py +18 -8
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
- flwr/server/utils/validator.py +9 -34
- flwr/simulation/app.py +4 -6
- flwr/simulation/legacy_app.py +4 -2
- flwr/simulation/run_simulation.py +1 -1
- flwr/simulation/simulationio_connection.py +2 -1
- flwr/superexec/exec_grpc.py +1 -1
- flwr/superexec/exec_servicer.py +23 -2
- {flwr-1.14.0.dist-info → flwr-1.15.1.dist-info}/METADATA +8 -8
- {flwr-1.14.0.dist-info → flwr-1.15.1.dist-info}/RECORD +103 -97
- {flwr-1.14.0.dist-info → flwr-1.15.1.dist-info}/LICENSE +0 -0
- {flwr-1.14.0.dist-info → flwr-1.15.1.dist-info}/WHEEL +0 -0
- {flwr-1.14.0.dist-info → flwr-1.15.1.dist-info}/entry_points.txt +0 -0
|
@@ -18,6 +18,7 @@
|
|
|
18
18
|
from logging import DEBUG, INFO
|
|
19
19
|
|
|
20
20
|
import grpc
|
|
21
|
+
from google.protobuf.json_format import MessageToDict
|
|
21
22
|
|
|
22
23
|
from flwr.common.logger import log
|
|
23
24
|
from flwr.common.typing import InvalidRunStatusException
|
|
@@ -30,10 +31,10 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
30
31
|
DeleteNodeResponse,
|
|
31
32
|
PingRequest,
|
|
32
33
|
PingResponse,
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
34
|
+
PullMessagesRequest,
|
|
35
|
+
PullMessagesResponse,
|
|
36
|
+
PushMessagesRequest,
|
|
37
|
+
PushMessagesResponse,
|
|
37
38
|
)
|
|
38
39
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
39
40
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
@@ -56,13 +57,13 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
|
56
57
|
) -> CreateNodeResponse:
|
|
57
58
|
"""."""
|
|
58
59
|
log(INFO, "[Fleet.CreateNode] Request ping_interval=%s", request.ping_interval)
|
|
59
|
-
log(DEBUG, "[Fleet.CreateNode] Request: %s", request)
|
|
60
|
+
log(DEBUG, "[Fleet.CreateNode] Request: %s", MessageToDict(request))
|
|
60
61
|
response = message_handler.create_node(
|
|
61
62
|
request=request,
|
|
62
63
|
state=self.state_factory.state(),
|
|
63
64
|
)
|
|
64
65
|
log(INFO, "[Fleet.CreateNode] Created node_id=%s", response.node.node_id)
|
|
65
|
-
log(DEBUG, "[Fleet.CreateNode] Response: %s", response)
|
|
66
|
+
log(DEBUG, "[Fleet.CreateNode] Response: %s", MessageToDict(response))
|
|
66
67
|
return response
|
|
67
68
|
|
|
68
69
|
def DeleteNode(
|
|
@@ -70,7 +71,7 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
|
70
71
|
) -> DeleteNodeResponse:
|
|
71
72
|
"""."""
|
|
72
73
|
log(INFO, "[Fleet.DeleteNode] Delete node_id=%s", request.node.node_id)
|
|
73
|
-
log(DEBUG, "[Fleet.DeleteNode] Request: %s", request)
|
|
74
|
+
log(DEBUG, "[Fleet.DeleteNode] Request: %s", MessageToDict(request))
|
|
74
75
|
return message_handler.delete_node(
|
|
75
76
|
request=request,
|
|
76
77
|
state=self.state_factory.state(),
|
|
@@ -78,38 +79,38 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
|
78
79
|
|
|
79
80
|
def Ping(self, request: PingRequest, context: grpc.ServicerContext) -> PingResponse:
|
|
80
81
|
"""."""
|
|
81
|
-
log(DEBUG, "[Fleet.Ping] Request: %s", request)
|
|
82
|
+
log(DEBUG, "[Fleet.Ping] Request: %s", MessageToDict(request))
|
|
82
83
|
return message_handler.ping(
|
|
83
84
|
request=request,
|
|
84
85
|
state=self.state_factory.state(),
|
|
85
86
|
)
|
|
86
87
|
|
|
87
|
-
def
|
|
88
|
-
self, request:
|
|
89
|
-
) ->
|
|
90
|
-
"""Pull
|
|
91
|
-
log(INFO, "[Fleet.
|
|
92
|
-
log(DEBUG, "[Fleet.
|
|
93
|
-
return message_handler.
|
|
88
|
+
def PullMessages(
|
|
89
|
+
self, request: PullMessagesRequest, context: grpc.ServicerContext
|
|
90
|
+
) -> PullMessagesResponse:
|
|
91
|
+
"""Pull Messages."""
|
|
92
|
+
log(INFO, "[Fleet.PullMessages] node_id=%s", request.node.node_id)
|
|
93
|
+
log(DEBUG, "[Fleet.PullMessages] Request: %s", MessageToDict(request))
|
|
94
|
+
return message_handler.pull_messages(
|
|
94
95
|
request=request,
|
|
95
96
|
state=self.state_factory.state(),
|
|
96
97
|
)
|
|
97
98
|
|
|
98
|
-
def
|
|
99
|
-
self, request:
|
|
100
|
-
) ->
|
|
101
|
-
"""Push
|
|
102
|
-
if request.
|
|
99
|
+
def PushMessages(
|
|
100
|
+
self, request: PushMessagesRequest, context: grpc.ServicerContext
|
|
101
|
+
) -> PushMessagesResponse:
|
|
102
|
+
"""Push Messages."""
|
|
103
|
+
if request.messages_list:
|
|
103
104
|
log(
|
|
104
105
|
INFO,
|
|
105
|
-
"[Fleet.
|
|
106
|
-
request.
|
|
106
|
+
"[Fleet.PushMessages] Push results from node_id=%s",
|
|
107
|
+
request.messages_list[0].metadata.src_node_id,
|
|
107
108
|
)
|
|
108
109
|
else:
|
|
109
|
-
log(INFO, "[Fleet.
|
|
110
|
+
log(INFO, "[Fleet.PushMessages] No task results to push")
|
|
110
111
|
|
|
111
112
|
try:
|
|
112
|
-
res = message_handler.
|
|
113
|
+
res = message_handler.push_messages(
|
|
113
114
|
request=request,
|
|
114
115
|
state=self.state_factory.state(),
|
|
115
116
|
)
|
|
@@ -15,93 +15,62 @@
|
|
|
15
15
|
"""Flower server interceptor."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
import
|
|
19
|
-
from
|
|
20
|
-
from logging import INFO, WARNING
|
|
21
|
-
from typing import Any, Callable, Optional, Union
|
|
18
|
+
import datetime
|
|
19
|
+
from typing import Any, Callable, Optional, cast
|
|
22
20
|
|
|
23
21
|
import grpc
|
|
24
|
-
from
|
|
25
|
-
|
|
26
|
-
from flwr.common
|
|
22
|
+
from google.protobuf.message import Message as GrpcMessage
|
|
23
|
+
|
|
24
|
+
from flwr.common import now
|
|
25
|
+
from flwr.common.constant import (
|
|
26
|
+
PUBLIC_KEY_HEADER,
|
|
27
|
+
SIGNATURE_HEADER,
|
|
28
|
+
SYSTEM_TIME_TOLERANCE,
|
|
29
|
+
TIMESTAMP_HEADER,
|
|
30
|
+
TIMESTAMP_TOLERANCE,
|
|
31
|
+
)
|
|
27
32
|
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
|
28
|
-
bytes_to_private_key,
|
|
29
33
|
bytes_to_public_key,
|
|
30
|
-
|
|
31
|
-
verify_hmac,
|
|
34
|
+
verify_signature,
|
|
32
35
|
)
|
|
33
|
-
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
34
36
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
35
37
|
CreateNodeRequest,
|
|
36
38
|
CreateNodeResponse,
|
|
37
|
-
DeleteNodeRequest,
|
|
38
|
-
DeleteNodeResponse,
|
|
39
|
-
PingRequest,
|
|
40
|
-
PingResponse,
|
|
41
|
-
PullTaskInsRequest,
|
|
42
|
-
PullTaskInsResponse,
|
|
43
|
-
PushTaskResRequest,
|
|
44
|
-
PushTaskResResponse,
|
|
45
39
|
)
|
|
46
|
-
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
47
|
-
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
48
40
|
from flwr.server.superlink.linkstate import LinkStateFactory
|
|
49
41
|
|
|
50
|
-
|
|
51
|
-
|
|
42
|
+
MIN_TIMESTAMP_DIFF = -SYSTEM_TIME_TOLERANCE
|
|
43
|
+
MAX_TIMESTAMP_DIFF = TIMESTAMP_TOLERANCE + SYSTEM_TIME_TOLERANCE
|
|
52
44
|
|
|
53
|
-
Request = Union[
|
|
54
|
-
CreateNodeRequest,
|
|
55
|
-
DeleteNodeRequest,
|
|
56
|
-
PullTaskInsRequest,
|
|
57
|
-
PushTaskResRequest,
|
|
58
|
-
GetRunRequest,
|
|
59
|
-
PingRequest,
|
|
60
|
-
GetFabRequest,
|
|
61
|
-
]
|
|
62
|
-
|
|
63
|
-
Response = Union[
|
|
64
|
-
CreateNodeResponse,
|
|
65
|
-
DeleteNodeResponse,
|
|
66
|
-
PullTaskInsResponse,
|
|
67
|
-
PushTaskResResponse,
|
|
68
|
-
GetRunResponse,
|
|
69
|
-
PingResponse,
|
|
70
|
-
GetFabResponse,
|
|
71
|
-
]
|
|
72
45
|
|
|
46
|
+
def _unary_unary_rpc_terminator(
|
|
47
|
+
message: str, code: Any = grpc.StatusCode.UNAUTHENTICATED
|
|
48
|
+
) -> grpc.RpcMethodHandler:
|
|
49
|
+
def terminate(_request: GrpcMessage, context: grpc.ServicerContext) -> GrpcMessage:
|
|
50
|
+
context.abort(code, message)
|
|
51
|
+
raise RuntimeError("Should not reach this point") # Make mypy happy
|
|
73
52
|
|
|
74
|
-
|
|
75
|
-
key_string: str, tuples: Sequence[tuple[str, Union[str, bytes]]]
|
|
76
|
-
) -> bytes:
|
|
77
|
-
value = next((value for key, value in tuples if key == key_string), "")
|
|
78
|
-
if isinstance(value, str):
|
|
79
|
-
return value.encode()
|
|
80
|
-
|
|
81
|
-
return value
|
|
53
|
+
return grpc.unary_unary_rpc_method_handler(terminate)
|
|
82
54
|
|
|
83
55
|
|
|
84
56
|
class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
85
|
-
"""Server interceptor for node authentication.
|
|
86
|
-
|
|
87
|
-
|
|
57
|
+
"""Server interceptor for node authentication.
|
|
58
|
+
|
|
59
|
+
Parameters
|
|
60
|
+
----------
|
|
61
|
+
state_factory : LinkStateFactory
|
|
62
|
+
A factory for creating new instances of LinkState.
|
|
63
|
+
auto_auth : bool (default: False)
|
|
64
|
+
If True, nodes are authenticated without requiring their public keys to be
|
|
65
|
+
pre-stored in the LinkState. If False, only nodes with pre-stored public keys
|
|
66
|
+
can be authenticated.
|
|
67
|
+
"""
|
|
68
|
+
|
|
69
|
+
def __init__(self, state_factory: LinkStateFactory, auto_auth: bool = False):
|
|
88
70
|
self.state_factory = state_factory
|
|
89
|
-
|
|
71
|
+
self.auto_auth = auto_auth
|
|
90
72
|
|
|
91
|
-
|
|
92
|
-
if len(self.node_public_keys) == 0:
|
|
93
|
-
log(WARNING, "Authentication enabled, but no known public keys configured")
|
|
94
|
-
|
|
95
|
-
private_key = state.get_server_private_key()
|
|
96
|
-
public_key = state.get_server_public_key()
|
|
97
|
-
|
|
98
|
-
if private_key is None or public_key is None:
|
|
99
|
-
raise ValueError("Error loading authentication keys")
|
|
100
|
-
|
|
101
|
-
self.server_private_key = bytes_to_private_key(private_key)
|
|
102
|
-
self.encoded_server_public_key = base64.urlsafe_b64encode(public_key)
|
|
103
|
-
|
|
104
|
-
def intercept_service(
|
|
73
|
+
def intercept_service( # pylint: disable=too-many-return-statements
|
|
105
74
|
self,
|
|
106
75
|
continuation: Callable[[Any], Any],
|
|
107
76
|
handler_call_details: grpc.HandlerCallDetails,
|
|
@@ -112,116 +81,89 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
|
112
81
|
metadata sent by the node. Continue RPC call if node is authenticated, else,
|
|
113
82
|
terminate RPC call by setting context to abort.
|
|
114
83
|
"""
|
|
84
|
+
# Filter out non-Fleet service calls
|
|
85
|
+
if not handler_call_details.method.startswith("/flwr.proto.Fleet/"):
|
|
86
|
+
return _unary_unary_rpc_terminator(
|
|
87
|
+
"This request should be sent to a different service.",
|
|
88
|
+
grpc.StatusCode.FAILED_PRECONDITION,
|
|
89
|
+
)
|
|
90
|
+
|
|
91
|
+
state = self.state_factory.state()
|
|
92
|
+
metadata_dict = dict(handler_call_details.invocation_metadata)
|
|
93
|
+
|
|
94
|
+
# Retrieve info from the metadata
|
|
95
|
+
try:
|
|
96
|
+
node_pk_bytes = cast(bytes, metadata_dict[PUBLIC_KEY_HEADER])
|
|
97
|
+
timestamp_iso = cast(str, metadata_dict[TIMESTAMP_HEADER])
|
|
98
|
+
signature = cast(bytes, metadata_dict[SIGNATURE_HEADER])
|
|
99
|
+
except KeyError:
|
|
100
|
+
return _unary_unary_rpc_terminator("Missing authentication metadata")
|
|
101
|
+
|
|
102
|
+
if not self.auto_auth:
|
|
103
|
+
# Abort the RPC call if the node public key is not found
|
|
104
|
+
if node_pk_bytes not in state.get_node_public_keys():
|
|
105
|
+
return _unary_unary_rpc_terminator("Public key not recognized")
|
|
106
|
+
|
|
107
|
+
# Verify the signature
|
|
108
|
+
node_pk = bytes_to_public_key(node_pk_bytes)
|
|
109
|
+
if not verify_signature(node_pk, timestamp_iso.encode("ascii"), signature):
|
|
110
|
+
return _unary_unary_rpc_terminator("Invalid signature")
|
|
111
|
+
|
|
112
|
+
# Verify the timestamp
|
|
113
|
+
current = now()
|
|
114
|
+
time_diff = current - datetime.datetime.fromisoformat(timestamp_iso)
|
|
115
|
+
# Abort the RPC call if the timestamp is too old or in the future
|
|
116
|
+
if not MIN_TIMESTAMP_DIFF < time_diff.total_seconds() < MAX_TIMESTAMP_DIFF:
|
|
117
|
+
return _unary_unary_rpc_terminator("Invalid timestamp")
|
|
118
|
+
|
|
119
|
+
# Continue the RPC call
|
|
120
|
+
expected_node_id = state.get_node_id(node_pk_bytes)
|
|
121
|
+
if not handler_call_details.method.endswith("CreateNode"):
|
|
122
|
+
# All calls, except for `CreateNode`, must provide a public key that is
|
|
123
|
+
# already mapped to a `node_id` (in `LinkState`)
|
|
124
|
+
if expected_node_id is None:
|
|
125
|
+
return _unary_unary_rpc_terminator("Invalid node ID")
|
|
115
126
|
# One of the method handlers in
|
|
116
127
|
# `flwr.server.superlink.fleet.grpc_rere.fleet_server.FleetServicer`
|
|
117
128
|
method_handler: grpc.RpcMethodHandler = continuation(handler_call_details)
|
|
118
|
-
return self.
|
|
129
|
+
return self._wrap_method_handler(
|
|
130
|
+
method_handler, expected_node_id, node_pk_bytes
|
|
131
|
+
)
|
|
119
132
|
|
|
120
|
-
def
|
|
121
|
-
self,
|
|
133
|
+
def _wrap_method_handler(
|
|
134
|
+
self,
|
|
135
|
+
method_handler: grpc.RpcMethodHandler,
|
|
136
|
+
expected_node_id: Optional[int],
|
|
137
|
+
node_public_key: bytes,
|
|
122
138
|
) -> grpc.RpcMethodHandler:
|
|
123
139
|
def _generic_method_handler(
|
|
124
|
-
request:
|
|
140
|
+
request: GrpcMessage,
|
|
125
141
|
context: grpc.ServicerContext,
|
|
126
|
-
) ->
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
_get_value_from_tuples(
|
|
149
|
-
_AUTH_TOKEN_HEADER, context.invocation_metadata()
|
|
150
|
-
)
|
|
151
|
-
)
|
|
152
|
-
public_key = bytes_to_public_key(node_public_key_bytes)
|
|
153
|
-
|
|
154
|
-
if not self._verify_hmac(public_key, request, hmac_value):
|
|
155
|
-
context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
|
|
156
|
-
|
|
157
|
-
# Verify node_id
|
|
158
|
-
node_id = self.state_factory.state().get_node_id(node_public_key_bytes)
|
|
159
|
-
|
|
160
|
-
if not self._verify_node_id(node_id, request):
|
|
161
|
-
context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
|
|
162
|
-
|
|
163
|
-
return method_handler.unary_unary(request, context) # type: ignore
|
|
142
|
+
) -> GrpcMessage:
|
|
143
|
+
# Verify the node ID
|
|
144
|
+
if not isinstance(request, CreateNodeRequest):
|
|
145
|
+
try:
|
|
146
|
+
if request.node.node_id != expected_node_id: # type: ignore
|
|
147
|
+
raise ValueError
|
|
148
|
+
except (AttributeError, ValueError):
|
|
149
|
+
context.abort(grpc.StatusCode.UNAUTHENTICATED, "Invalid node ID")
|
|
150
|
+
|
|
151
|
+
response: GrpcMessage = method_handler.unary_unary(request, context)
|
|
152
|
+
|
|
153
|
+
# Set the public key after a successful CreateNode request
|
|
154
|
+
if isinstance(response, CreateNodeResponse):
|
|
155
|
+
state = self.state_factory.state()
|
|
156
|
+
try:
|
|
157
|
+
state.set_node_public_key(response.node.node_id, node_public_key)
|
|
158
|
+
except ValueError as e:
|
|
159
|
+
# Remove newly created node if setting the public key fails
|
|
160
|
+
state.delete_node(response.node.node_id)
|
|
161
|
+
context.abort(grpc.StatusCode.UNAUTHENTICATED, str(e))
|
|
162
|
+
|
|
163
|
+
return response
|
|
164
164
|
|
|
165
165
|
return grpc.unary_unary_rpc_method_handler(
|
|
166
166
|
_generic_method_handler,
|
|
167
167
|
request_deserializer=method_handler.request_deserializer,
|
|
168
168
|
response_serializer=method_handler.response_serializer,
|
|
169
169
|
)
|
|
170
|
-
|
|
171
|
-
def _verify_node_id(
|
|
172
|
-
self,
|
|
173
|
-
node_id: Optional[int],
|
|
174
|
-
request: Union[
|
|
175
|
-
DeleteNodeRequest,
|
|
176
|
-
PullTaskInsRequest,
|
|
177
|
-
PushTaskResRequest,
|
|
178
|
-
GetRunRequest,
|
|
179
|
-
PingRequest,
|
|
180
|
-
GetFabRequest,
|
|
181
|
-
],
|
|
182
|
-
) -> bool:
|
|
183
|
-
if node_id is None:
|
|
184
|
-
return False
|
|
185
|
-
if isinstance(request, PushTaskResRequest):
|
|
186
|
-
if len(request.task_res_list) == 0:
|
|
187
|
-
return False
|
|
188
|
-
return request.task_res_list[0].task.producer.node_id == node_id
|
|
189
|
-
if isinstance(request, GetRunRequest):
|
|
190
|
-
return node_id in self.state_factory.state().get_nodes(request.run_id)
|
|
191
|
-
return request.node.node_id == node_id
|
|
192
|
-
|
|
193
|
-
def _verify_hmac(
|
|
194
|
-
self, public_key: ec.EllipticCurvePublicKey, request: Request, hmac_value: bytes
|
|
195
|
-
) -> bool:
|
|
196
|
-
shared_secret = generate_shared_key(self.server_private_key, public_key)
|
|
197
|
-
message_bytes = request.SerializeToString(deterministic=True)
|
|
198
|
-
return verify_hmac(shared_secret, message_bytes, hmac_value)
|
|
199
|
-
|
|
200
|
-
def _create_authenticated_node(
|
|
201
|
-
self,
|
|
202
|
-
public_key_bytes: bytes,
|
|
203
|
-
request: CreateNodeRequest,
|
|
204
|
-
context: grpc.ServicerContext,
|
|
205
|
-
) -> CreateNodeResponse:
|
|
206
|
-
context.send_initial_metadata(
|
|
207
|
-
(
|
|
208
|
-
(
|
|
209
|
-
_PUBLIC_KEY_HEADER,
|
|
210
|
-
self.encoded_server_public_key,
|
|
211
|
-
),
|
|
212
|
-
)
|
|
213
|
-
)
|
|
214
|
-
state = self.state_factory.state()
|
|
215
|
-
node_id = state.get_node_id(public_key_bytes)
|
|
216
|
-
|
|
217
|
-
# Handle `CreateNode` here instead of calling the default method handler
|
|
218
|
-
# Return previously assigned `node_id` for the provided `public_key`
|
|
219
|
-
if node_id is not None:
|
|
220
|
-
state.acknowledge_ping(node_id, request.ping_interval)
|
|
221
|
-
return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
|
|
222
|
-
|
|
223
|
-
# No `node_id` exists for the provided `public_key`
|
|
224
|
-
# Handle `CreateNode` here instead of calling the default method handler
|
|
225
|
-
# Note: the innermost `CreateNode` method will never be called
|
|
226
|
-
node_id = state.create_node(request.ping_interval, public_key_bytes)
|
|
227
|
-
return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
|
|
@@ -20,7 +20,14 @@ from typing import Optional
|
|
|
20
20
|
from uuid import UUID
|
|
21
21
|
|
|
22
22
|
from flwr.common.constant import Status
|
|
23
|
-
from flwr.common.serde import
|
|
23
|
+
from flwr.common.serde import (
|
|
24
|
+
fab_to_proto,
|
|
25
|
+
message_from_proto,
|
|
26
|
+
message_from_taskins,
|
|
27
|
+
message_to_proto,
|
|
28
|
+
message_to_taskres,
|
|
29
|
+
user_config_to_proto,
|
|
30
|
+
)
|
|
24
31
|
from flwr.common.typing import Fab, InvalidRunStatusException
|
|
25
32
|
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
26
33
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
@@ -30,10 +37,10 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
30
37
|
DeleteNodeResponse,
|
|
31
38
|
PingRequest,
|
|
32
39
|
PingResponse,
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
40
|
+
PullMessagesRequest,
|
|
41
|
+
PullMessagesResponse,
|
|
42
|
+
PushMessagesRequest,
|
|
43
|
+
PushMessagesResponse,
|
|
37
44
|
Reconnect,
|
|
38
45
|
)
|
|
39
46
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
@@ -42,7 +49,7 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
|
42
49
|
GetRunResponse,
|
|
43
50
|
Run,
|
|
44
51
|
)
|
|
45
|
-
from flwr.proto.task_pb2 import TaskIns
|
|
52
|
+
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
|
|
46
53
|
from flwr.server.superlink.ffs.ffs import Ffs
|
|
47
54
|
from flwr.server.superlink.linkstate import LinkState
|
|
48
55
|
from flwr.server.superlink.utils import check_abort
|
|
@@ -55,13 +62,13 @@ def create_node(
|
|
|
55
62
|
"""."""
|
|
56
63
|
# Create node
|
|
57
64
|
node_id = state.create_node(ping_interval=request.ping_interval)
|
|
58
|
-
return CreateNodeResponse(node=Node(node_id=node_id
|
|
65
|
+
return CreateNodeResponse(node=Node(node_id=node_id))
|
|
59
66
|
|
|
60
67
|
|
|
61
68
|
def delete_node(request: DeleteNodeRequest, state: LinkState) -> DeleteNodeResponse:
|
|
62
69
|
"""."""
|
|
63
70
|
# Validate node_id
|
|
64
|
-
if request.node.
|
|
71
|
+
if request.node.node_id == 0: # i.e. unset `node_id`
|
|
65
72
|
return DeleteNodeResponse()
|
|
66
73
|
|
|
67
74
|
# Update state
|
|
@@ -78,27 +85,33 @@ def ping(
|
|
|
78
85
|
return PingResponse(success=res)
|
|
79
86
|
|
|
80
87
|
|
|
81
|
-
def
|
|
82
|
-
|
|
88
|
+
def pull_messages(
|
|
89
|
+
request: PullMessagesRequest, state: LinkState
|
|
90
|
+
) -> PullMessagesResponse:
|
|
91
|
+
"""Pull Messages handler."""
|
|
83
92
|
# Get node_id if client node is not anonymous
|
|
84
93
|
node = request.node # pylint: disable=no-member
|
|
85
|
-
node_id:
|
|
94
|
+
node_id: int = node.node_id
|
|
86
95
|
|
|
87
96
|
# Retrieve TaskIns from State
|
|
88
97
|
task_ins_list: list[TaskIns] = state.get_task_ins(node_id=node_id, limit=1)
|
|
89
98
|
|
|
90
|
-
#
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
99
|
+
# Convert to Messages
|
|
100
|
+
msg_proto = []
|
|
101
|
+
for task_ins in task_ins_list:
|
|
102
|
+
msg = message_from_taskins(task_ins)
|
|
103
|
+
msg_proto.append(message_to_proto(msg))
|
|
104
|
+
|
|
105
|
+
return PullMessagesResponse(messages_list=msg_proto)
|
|
95
106
|
|
|
96
107
|
|
|
97
|
-
def
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
#
|
|
108
|
+
def push_messages(
|
|
109
|
+
request: PushMessagesRequest, state: LinkState
|
|
110
|
+
) -> PushMessagesResponse:
|
|
111
|
+
"""Push Messages handler."""
|
|
112
|
+
# Convert Message to TaskRes
|
|
113
|
+
msg = message_from_proto(message_proto=request.messages_list[0])
|
|
114
|
+
task_res = message_to_taskres(msg)
|
|
102
115
|
|
|
103
116
|
# Abort if the run is not running
|
|
104
117
|
abort_msg = check_abort(
|
|
@@ -113,12 +126,12 @@ def push_task_res(request: PushTaskResRequest, state: LinkState) -> PushTaskResR
|
|
|
113
126
|
task_res.task.pushed_at = time.time()
|
|
114
127
|
|
|
115
128
|
# Store TaskRes in State
|
|
116
|
-
|
|
129
|
+
message_id: Optional[UUID] = state.store_task_res(task_res=task_res)
|
|
117
130
|
|
|
118
131
|
# Build response
|
|
119
|
-
response =
|
|
132
|
+
response = PushMessagesResponse(
|
|
120
133
|
reconnect=Reconnect(reconnect=5),
|
|
121
|
-
results={str(
|
|
134
|
+
results={str(message_id): 0},
|
|
122
135
|
)
|
|
123
136
|
return response
|
|
124
137
|
|
|
@@ -17,13 +17,12 @@
|
|
|
17
17
|
|
|
18
18
|
from __future__ import annotations
|
|
19
19
|
|
|
20
|
-
import sys
|
|
21
20
|
from collections.abc import Awaitable
|
|
22
21
|
from typing import Callable, TypeVar, cast
|
|
23
22
|
|
|
24
23
|
from google.protobuf.message import Message as GrpcMessage
|
|
25
24
|
|
|
26
|
-
from flwr.common.
|
|
25
|
+
from flwr.common.exit import ExitCode, flwr_exit
|
|
27
26
|
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
28
27
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
29
28
|
CreateNodeRequest,
|
|
@@ -32,10 +31,10 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
32
31
|
DeleteNodeResponse,
|
|
33
32
|
PingRequest,
|
|
34
33
|
PingResponse,
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
34
|
+
PullMessagesRequest,
|
|
35
|
+
PullMessagesResponse,
|
|
36
|
+
PushMessagesRequest,
|
|
37
|
+
PushMessagesResponse,
|
|
39
38
|
)
|
|
40
39
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
41
40
|
from flwr.server.superlink.ffs.ffs import Ffs
|
|
@@ -51,7 +50,7 @@ try:
|
|
|
51
50
|
from starlette.responses import Response
|
|
52
51
|
from starlette.routing import Route
|
|
53
52
|
except ModuleNotFoundError:
|
|
54
|
-
|
|
53
|
+
flwr_exit(ExitCode.COMMON_MISSING_EXTRA_REST)
|
|
55
54
|
|
|
56
55
|
|
|
57
56
|
GrpcRequest = TypeVar("GrpcRequest", bound=GrpcMessage)
|
|
@@ -107,25 +106,24 @@ async def delete_node(request: DeleteNodeRequest) -> DeleteNodeResponse:
|
|
|
107
106
|
return message_handler.delete_node(request=request, state=state)
|
|
108
107
|
|
|
109
108
|
|
|
110
|
-
@rest_request_response(
|
|
111
|
-
async def
|
|
112
|
-
"""Pull
|
|
109
|
+
@rest_request_response(PullMessagesRequest)
|
|
110
|
+
async def pull_message(request: PullMessagesRequest) -> PullMessagesResponse:
|
|
111
|
+
"""Pull PullMessages."""
|
|
113
112
|
# Get state from app
|
|
114
113
|
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
|
115
114
|
|
|
116
115
|
# Handle message
|
|
117
|
-
return message_handler.
|
|
116
|
+
return message_handler.pull_messages(request=request, state=state)
|
|
118
117
|
|
|
119
118
|
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
"""Push TaskRes."""
|
|
119
|
+
@rest_request_response(PushMessagesRequest)
|
|
120
|
+
async def push_message(request: PushMessagesRequest) -> PushMessagesResponse:
|
|
121
|
+
"""Pull PushMessages."""
|
|
124
122
|
# Get state from app
|
|
125
123
|
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
|
126
124
|
|
|
127
125
|
# Handle message
|
|
128
|
-
return message_handler.
|
|
126
|
+
return message_handler.push_messages(request=request, state=state)
|
|
129
127
|
|
|
130
128
|
|
|
131
129
|
@rest_request_response(PingRequest)
|
|
@@ -164,8 +162,8 @@ async def get_fab(request: GetFabRequest) -> GetFabResponse:
|
|
|
164
162
|
routes = [
|
|
165
163
|
Route("/api/v0/fleet/create-node", create_node, methods=["POST"]),
|
|
166
164
|
Route("/api/v0/fleet/delete-node", delete_node, methods=["POST"]),
|
|
167
|
-
Route("/api/v0/fleet/pull-
|
|
168
|
-
Route("/api/v0/fleet/push-
|
|
165
|
+
Route("/api/v0/fleet/pull-messages", pull_message, methods=["POST"]),
|
|
166
|
+
Route("/api/v0/fleet/push-messages", push_message, methods=["POST"]),
|
|
169
167
|
Route("/api/v0/fleet/ping", ping, methods=["POST"]),
|
|
170
168
|
Route("/api/v0/fleet/get-run", get_run, methods=["POST"]),
|
|
171
169
|
Route("/api/v0/fleet/get-fab", get_fab, methods=["POST"]),
|
|
@@ -182,8 +182,8 @@ def run_api(
|
|
|
182
182
|
f_stop: threading.Event,
|
|
183
183
|
) -> None:
|
|
184
184
|
"""Run the VCE."""
|
|
185
|
-
taskins_queue:
|
|
186
|
-
taskres_queue:
|
|
185
|
+
taskins_queue: Queue[TaskIns] = Queue()
|
|
186
|
+
taskres_queue: Queue[TaskRes] = Queue()
|
|
187
187
|
|
|
188
188
|
try:
|
|
189
189
|
|