flwr-nightly 1.15.0.dev20250104__py3-none-any.whl → 1.15.0.dev20250123__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/cli_user_auth_interceptor.py +6 -2
- flwr/cli/config_utils.py +23 -146
- flwr/cli/constant.py +27 -0
- flwr/cli/install.py +1 -1
- flwr/cli/log.py +17 -2
- flwr/cli/login/login.py +20 -5
- flwr/cli/ls.py +10 -2
- flwr/cli/run/run.py +20 -10
- flwr/cli/stop.py +9 -1
- flwr/cli/utils.py +4 -4
- flwr/client/app.py +36 -48
- flwr/client/clientapp/app.py +4 -6
- flwr/client/clientapp/utils.py +1 -1
- flwr/client/grpc_client/connection.py +0 -6
- flwr/client/grpc_rere_client/client_interceptor.py +19 -119
- flwr/client/grpc_rere_client/connection.py +34 -24
- flwr/client/grpc_rere_client/grpc_adapter.py +16 -0
- flwr/client/rest_client/connection.py +34 -26
- flwr/client/supernode/app.py +14 -20
- flwr/common/auth_plugin/auth_plugin.py +34 -23
- flwr/common/config.py +152 -15
- flwr/common/constant.py +11 -8
- flwr/common/exit/__init__.py +24 -0
- flwr/common/exit/exit.py +99 -0
- flwr/common/exit/exit_code.py +93 -0
- flwr/common/exit_handlers.py +24 -10
- flwr/common/grpc.py +161 -3
- flwr/common/logger.py +1 -1
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +45 -0
- flwr/common/serde.py +6 -4
- flwr/common/typing.py +20 -0
- flwr/proto/clientappio_pb2.py +13 -3
- flwr/proto/clientappio_pb2_grpc.py +63 -12
- flwr/proto/error_pb2.py +13 -3
- flwr/proto/error_pb2_grpc.py +20 -0
- flwr/proto/exec_pb2.py +27 -29
- flwr/proto/exec_pb2.pyi +27 -54
- flwr/proto/exec_pb2_grpc.py +105 -24
- flwr/proto/fab_pb2.py +13 -3
- flwr/proto/fab_pb2_grpc.py +20 -0
- flwr/proto/fleet_pb2.py +54 -31
- flwr/proto/fleet_pb2.pyi +84 -0
- flwr/proto/fleet_pb2_grpc.py +207 -28
- flwr/proto/fleet_pb2_grpc.pyi +26 -0
- flwr/proto/grpcadapter_pb2.py +14 -4
- flwr/proto/grpcadapter_pb2_grpc.py +35 -4
- flwr/proto/log_pb2.py +13 -3
- flwr/proto/log_pb2_grpc.py +20 -0
- flwr/proto/message_pb2.py +15 -5
- flwr/proto/message_pb2_grpc.py +20 -0
- flwr/proto/node_pb2.py +15 -5
- flwr/proto/node_pb2.pyi +1 -4
- flwr/proto/node_pb2_grpc.py +20 -0
- flwr/proto/recordset_pb2.py +18 -8
- flwr/proto/recordset_pb2_grpc.py +20 -0
- flwr/proto/run_pb2.py +16 -6
- flwr/proto/run_pb2_grpc.py +20 -0
- flwr/proto/serverappio_pb2.py +32 -14
- flwr/proto/serverappio_pb2.pyi +56 -0
- flwr/proto/serverappio_pb2_grpc.py +261 -44
- flwr/proto/serverappio_pb2_grpc.pyi +20 -0
- flwr/proto/simulationio_pb2.py +13 -3
- flwr/proto/simulationio_pb2_grpc.py +105 -24
- flwr/proto/task_pb2.py +13 -3
- flwr/proto/task_pb2_grpc.py +20 -0
- flwr/proto/transport_pb2.py +20 -10
- flwr/proto/transport_pb2_grpc.py +35 -4
- flwr/server/app.py +87 -38
- flwr/server/compat/app_utils.py +0 -1
- flwr/server/compat/driver_client_proxy.py +1 -2
- flwr/server/driver/grpc_driver.py +5 -2
- flwr/server/driver/inmemory_driver.py +2 -1
- flwr/server/serverapp/app.py +5 -6
- flwr/server/superlink/driver/serverappio_grpc.py +1 -1
- flwr/server/superlink/driver/serverappio_servicer.py +132 -14
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +20 -88
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +2 -165
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +38 -0
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +95 -168
- flwr/server/superlink/fleet/message_handler/message_handler.py +66 -5
- flwr/server/superlink/fleet/rest_rere/rest_api.py +28 -3
- flwr/server/superlink/fleet/vce/vce_api.py +2 -2
- flwr/server/superlink/linkstate/in_memory_linkstate.py +40 -48
- flwr/server/superlink/linkstate/linkstate.py +15 -22
- flwr/server/superlink/linkstate/sqlite_linkstate.py +80 -99
- flwr/server/superlink/linkstate/utils.py +18 -8
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
- flwr/server/utils/validator.py +9 -34
- flwr/simulation/app.py +4 -6
- flwr/simulation/legacy_app.py +4 -2
- flwr/simulation/run_simulation.py +1 -1
- flwr/superexec/exec_grpc.py +1 -1
- flwr/superexec/exec_servicer.py +23 -2
- {flwr_nightly-1.15.0.dev20250104.dist-info → flwr_nightly-1.15.0.dev20250123.dist-info}/METADATA +7 -7
- {flwr_nightly-1.15.0.dev20250104.dist-info → flwr_nightly-1.15.0.dev20250123.dist-info}/RECORD +98 -94
- {flwr_nightly-1.15.0.dev20250104.dist-info → flwr_nightly-1.15.0.dev20250123.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.15.0.dev20250104.dist-info → flwr_nightly-1.15.0.dev20250123.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.15.0.dev20250104.dist-info → flwr_nightly-1.15.0.dev20250123.dist-info}/entry_points.txt +0 -0
@@ -30,8 +30,12 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
30
30
|
DeleteNodeResponse,
|
31
31
|
PingRequest,
|
32
32
|
PingResponse,
|
33
|
+
PullMessagesRequest,
|
34
|
+
PullMessagesResponse,
|
33
35
|
PullTaskInsRequest,
|
34
36
|
PullTaskInsResponse,
|
37
|
+
PushMessagesRequest,
|
38
|
+
PushMessagesResponse,
|
35
39
|
PushTaskResRequest,
|
36
40
|
PushTaskResResponse,
|
37
41
|
)
|
@@ -95,6 +99,17 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
95
99
|
state=self.state_factory.state(),
|
96
100
|
)
|
97
101
|
|
102
|
+
def PullMessages(
|
103
|
+
self, request: PullMessagesRequest, context: grpc.ServicerContext
|
104
|
+
) -> PullMessagesResponse:
|
105
|
+
"""Pull Messages."""
|
106
|
+
log(INFO, "[Fleet.PullMessages] node_id=%s", request.node.node_id)
|
107
|
+
log(DEBUG, "[Fleet.PullMessages] Request: %s", request)
|
108
|
+
return message_handler.pull_messages(
|
109
|
+
request=request,
|
110
|
+
state=self.state_factory.state(),
|
111
|
+
)
|
112
|
+
|
98
113
|
def PushTaskRes(
|
99
114
|
self, request: PushTaskResRequest, context: grpc.ServicerContext
|
100
115
|
) -> PushTaskResResponse:
|
@@ -118,6 +133,29 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
118
133
|
|
119
134
|
return res
|
120
135
|
|
136
|
+
def PushMessages(
|
137
|
+
self, request: PushMessagesRequest, context: grpc.ServicerContext
|
138
|
+
) -> PushMessagesResponse:
|
139
|
+
"""Push Messages."""
|
140
|
+
if request.messages_list:
|
141
|
+
log(
|
142
|
+
INFO,
|
143
|
+
"[Fleet.PushMessages] Push results from node_id=%s",
|
144
|
+
request.messages_list[0].metadata.src_node_id,
|
145
|
+
)
|
146
|
+
else:
|
147
|
+
log(INFO, "[Fleet.PushMessages] No task results to push")
|
148
|
+
|
149
|
+
try:
|
150
|
+
res = message_handler.push_messages(
|
151
|
+
request=request,
|
152
|
+
state=self.state_factory.state(),
|
153
|
+
)
|
154
|
+
except InvalidRunStatusException as e:
|
155
|
+
abort_grpc_context(e.message, context)
|
156
|
+
|
157
|
+
return res
|
158
|
+
|
121
159
|
def GetRun(
|
122
160
|
self, request: GetRunRequest, context: grpc.ServicerContext
|
123
161
|
) -> GetRunResponse:
|
@@ -15,91 +15,54 @@
|
|
15
15
|
"""Flower server interceptor."""
|
16
16
|
|
17
17
|
|
18
|
-
import
|
19
|
-
from
|
20
|
-
from logging import INFO, WARNING
|
21
|
-
from typing import Any, Callable, Optional, Union
|
18
|
+
import datetime
|
19
|
+
from typing import Any, Callable, Optional, cast
|
22
20
|
|
23
21
|
import grpc
|
24
|
-
from
|
25
|
-
|
26
|
-
from flwr.common
|
22
|
+
from google.protobuf.message import Message as GrpcMessage
|
23
|
+
|
24
|
+
from flwr.common import now
|
25
|
+
from flwr.common.constant import (
|
26
|
+
PUBLIC_KEY_HEADER,
|
27
|
+
SIGNATURE_HEADER,
|
28
|
+
TIMESTAMP_HEADER,
|
29
|
+
TIMESTAMP_TOLERANCE,
|
30
|
+
)
|
27
31
|
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
28
|
-
bytes_to_private_key,
|
29
32
|
bytes_to_public_key,
|
30
|
-
|
31
|
-
verify_hmac,
|
33
|
+
verify_signature,
|
32
34
|
)
|
33
|
-
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
34
35
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
35
36
|
CreateNodeRequest,
|
36
37
|
CreateNodeResponse,
|
37
|
-
DeleteNodeRequest,
|
38
|
-
DeleteNodeResponse,
|
39
|
-
PingRequest,
|
40
|
-
PingResponse,
|
41
|
-
PullTaskInsRequest,
|
42
|
-
PullTaskInsResponse,
|
43
|
-
PushTaskResRequest,
|
44
|
-
PushTaskResResponse,
|
45
38
|
)
|
46
|
-
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
47
|
-
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
48
39
|
from flwr.server.superlink.linkstate import LinkStateFactory
|
49
40
|
|
50
|
-
_PUBLIC_KEY_HEADER = "public-key"
|
51
|
-
_AUTH_TOKEN_HEADER = "auth-token"
|
52
|
-
|
53
|
-
Request = Union[
|
54
|
-
CreateNodeRequest,
|
55
|
-
DeleteNodeRequest,
|
56
|
-
PullTaskInsRequest,
|
57
|
-
PushTaskResRequest,
|
58
|
-
GetRunRequest,
|
59
|
-
PingRequest,
|
60
|
-
GetFabRequest,
|
61
|
-
]
|
62
|
-
|
63
|
-
Response = Union[
|
64
|
-
CreateNodeResponse,
|
65
|
-
DeleteNodeResponse,
|
66
|
-
PullTaskInsResponse,
|
67
|
-
PushTaskResResponse,
|
68
|
-
GetRunResponse,
|
69
|
-
PingResponse,
|
70
|
-
GetFabResponse,
|
71
|
-
]
|
72
|
-
|
73
41
|
|
74
|
-
def
|
75
|
-
|
76
|
-
)
|
77
|
-
|
78
|
-
if isinstance(value, str):
|
79
|
-
return value.encode()
|
42
|
+
def _unary_unary_rpc_terminator(message: str) -> grpc.RpcMethodHandler:
|
43
|
+
def terminate(_request: GrpcMessage, context: grpc.ServicerContext) -> GrpcMessage:
|
44
|
+
context.abort(grpc.StatusCode.UNAUTHENTICATED, message)
|
45
|
+
raise RuntimeError("Should not reach this point") # Make mypy happy
|
80
46
|
|
81
|
-
return
|
47
|
+
return grpc.unary_unary_rpc_method_handler(terminate)
|
82
48
|
|
83
49
|
|
84
50
|
class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
85
|
-
"""Server interceptor for node authentication.
|
86
|
-
|
87
|
-
|
51
|
+
"""Server interceptor for node authentication.
|
52
|
+
|
53
|
+
Parameters
|
54
|
+
----------
|
55
|
+
state_factory : LinkStateFactory
|
56
|
+
A factory for creating new instances of LinkState.
|
57
|
+
auto_auth : bool (default: False)
|
58
|
+
If True, nodes are authenticated without requiring their public keys to be
|
59
|
+
pre-stored in the LinkState. If False, only nodes with pre-stored public keys
|
60
|
+
can be authenticated.
|
61
|
+
"""
|
62
|
+
|
63
|
+
def __init__(self, state_factory: LinkStateFactory, auto_auth: bool = False):
|
88
64
|
self.state_factory = state_factory
|
89
|
-
|
90
|
-
|
91
|
-
self.node_public_keys = state.get_node_public_keys()
|
92
|
-
if len(self.node_public_keys) == 0:
|
93
|
-
log(WARNING, "Authentication enabled, but no known public keys configured")
|
94
|
-
|
95
|
-
private_key = state.get_server_private_key()
|
96
|
-
public_key = state.get_server_public_key()
|
97
|
-
|
98
|
-
if private_key is None or public_key is None:
|
99
|
-
raise ValueError("Error loading authentication keys")
|
100
|
-
|
101
|
-
self.server_private_key = bytes_to_private_key(private_key)
|
102
|
-
self.encoded_server_public_key = base64.urlsafe_b64encode(public_key)
|
65
|
+
self.auto_auth = auto_auth
|
103
66
|
|
104
67
|
def intercept_service(
|
105
68
|
self,
|
@@ -112,116 +75,80 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
112
75
|
metadata sent by the node. Continue RPC call if node is authenticated, else,
|
113
76
|
terminate RPC call by setting context to abort.
|
114
77
|
"""
|
78
|
+
state = self.state_factory.state()
|
79
|
+
metadata_dict = dict(handler_call_details.invocation_metadata)
|
80
|
+
|
81
|
+
# Retrieve info from the metadata
|
82
|
+
try:
|
83
|
+
node_pk_bytes = cast(bytes, metadata_dict[PUBLIC_KEY_HEADER])
|
84
|
+
timestamp_iso = cast(str, metadata_dict[TIMESTAMP_HEADER])
|
85
|
+
signature = cast(bytes, metadata_dict[SIGNATURE_HEADER])
|
86
|
+
except KeyError:
|
87
|
+
return _unary_unary_rpc_terminator("Missing authentication metadata")
|
88
|
+
|
89
|
+
if not self.auto_auth:
|
90
|
+
# Abort the RPC call if the node public key is not found
|
91
|
+
if node_pk_bytes not in state.get_node_public_keys():
|
92
|
+
return _unary_unary_rpc_terminator("Public key not recognized")
|
93
|
+
|
94
|
+
# Verify the signature
|
95
|
+
node_pk = bytes_to_public_key(node_pk_bytes)
|
96
|
+
if not verify_signature(node_pk, timestamp_iso.encode("ascii"), signature):
|
97
|
+
return _unary_unary_rpc_terminator("Invalid signature")
|
98
|
+
|
99
|
+
# Verify the timestamp
|
100
|
+
current = now()
|
101
|
+
time_diff = current - datetime.datetime.fromisoformat(timestamp_iso)
|
102
|
+
# Abort the RPC call if the timestamp is too old or in the future
|
103
|
+
if not 0 < time_diff.total_seconds() < TIMESTAMP_TOLERANCE:
|
104
|
+
return _unary_unary_rpc_terminator("Invalid timestamp")
|
105
|
+
|
106
|
+
# Continue the RPC call
|
107
|
+
expected_node_id = state.get_node_id(node_pk_bytes)
|
108
|
+
if not handler_call_details.method.endswith("CreateNode"):
|
109
|
+
if expected_node_id is None:
|
110
|
+
return _unary_unary_rpc_terminator("Invalid node ID")
|
115
111
|
# One of the method handlers in
|
116
112
|
# `flwr.server.superlink.fleet.grpc_rere.fleet_server.FleetServicer`
|
117
113
|
method_handler: grpc.RpcMethodHandler = continuation(handler_call_details)
|
118
|
-
return self.
|
114
|
+
return self._wrap_method_handler(
|
115
|
+
method_handler, expected_node_id, node_pk_bytes
|
116
|
+
)
|
119
117
|
|
120
|
-
def
|
121
|
-
self,
|
118
|
+
def _wrap_method_handler(
|
119
|
+
self,
|
120
|
+
method_handler: grpc.RpcMethodHandler,
|
121
|
+
expected_node_id: Optional[int],
|
122
|
+
node_public_key: bytes,
|
122
123
|
) -> grpc.RpcMethodHandler:
|
123
124
|
def _generic_method_handler(
|
124
|
-
request:
|
125
|
+
request: GrpcMessage,
|
125
126
|
context: grpc.ServicerContext,
|
126
|
-
) ->
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
_get_value_from_tuples(
|
149
|
-
_AUTH_TOKEN_HEADER, context.invocation_metadata()
|
150
|
-
)
|
151
|
-
)
|
152
|
-
public_key = bytes_to_public_key(node_public_key_bytes)
|
153
|
-
|
154
|
-
if not self._verify_hmac(public_key, request, hmac_value):
|
155
|
-
context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
|
156
|
-
|
157
|
-
# Verify node_id
|
158
|
-
node_id = self.state_factory.state().get_node_id(node_public_key_bytes)
|
159
|
-
|
160
|
-
if not self._verify_node_id(node_id, request):
|
161
|
-
context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
|
162
|
-
|
163
|
-
return method_handler.unary_unary(request, context) # type: ignore
|
127
|
+
) -> GrpcMessage:
|
128
|
+
# Verify the node ID
|
129
|
+
if not isinstance(request, CreateNodeRequest):
|
130
|
+
try:
|
131
|
+
if request.node.node_id != expected_node_id: # type: ignore
|
132
|
+
raise ValueError
|
133
|
+
except (AttributeError, ValueError):
|
134
|
+
context.abort(grpc.StatusCode.UNAUTHENTICATED, "Invalid node ID")
|
135
|
+
|
136
|
+
response: GrpcMessage = method_handler.unary_unary(request, context)
|
137
|
+
|
138
|
+
# Set the public key after a successful CreateNode request
|
139
|
+
if isinstance(response, CreateNodeResponse):
|
140
|
+
state = self.state_factory.state()
|
141
|
+
try:
|
142
|
+
state.set_node_public_key(response.node.node_id, node_public_key)
|
143
|
+
except ValueError as e:
|
144
|
+
# Remove newly created node if setting the public key fails
|
145
|
+
state.delete_node(response.node.node_id)
|
146
|
+
context.abort(grpc.StatusCode.UNAUTHENTICATED, str(e))
|
147
|
+
|
148
|
+
return response
|
164
149
|
|
165
150
|
return grpc.unary_unary_rpc_method_handler(
|
166
151
|
_generic_method_handler,
|
167
152
|
request_deserializer=method_handler.request_deserializer,
|
168
153
|
response_serializer=method_handler.response_serializer,
|
169
154
|
)
|
170
|
-
|
171
|
-
def _verify_node_id(
|
172
|
-
self,
|
173
|
-
node_id: Optional[int],
|
174
|
-
request: Union[
|
175
|
-
DeleteNodeRequest,
|
176
|
-
PullTaskInsRequest,
|
177
|
-
PushTaskResRequest,
|
178
|
-
GetRunRequest,
|
179
|
-
PingRequest,
|
180
|
-
GetFabRequest,
|
181
|
-
],
|
182
|
-
) -> bool:
|
183
|
-
if node_id is None:
|
184
|
-
return False
|
185
|
-
if isinstance(request, PushTaskResRequest):
|
186
|
-
if len(request.task_res_list) == 0:
|
187
|
-
return False
|
188
|
-
return request.task_res_list[0].task.producer.node_id == node_id
|
189
|
-
if isinstance(request, GetRunRequest):
|
190
|
-
return node_id in self.state_factory.state().get_nodes(request.run_id)
|
191
|
-
return request.node.node_id == node_id
|
192
|
-
|
193
|
-
def _verify_hmac(
|
194
|
-
self, public_key: ec.EllipticCurvePublicKey, request: Request, hmac_value: bytes
|
195
|
-
) -> bool:
|
196
|
-
shared_secret = generate_shared_key(self.server_private_key, public_key)
|
197
|
-
message_bytes = request.SerializeToString(deterministic=True)
|
198
|
-
return verify_hmac(shared_secret, message_bytes, hmac_value)
|
199
|
-
|
200
|
-
def _create_authenticated_node(
|
201
|
-
self,
|
202
|
-
public_key_bytes: bytes,
|
203
|
-
request: CreateNodeRequest,
|
204
|
-
context: grpc.ServicerContext,
|
205
|
-
) -> CreateNodeResponse:
|
206
|
-
context.send_initial_metadata(
|
207
|
-
(
|
208
|
-
(
|
209
|
-
_PUBLIC_KEY_HEADER,
|
210
|
-
self.encoded_server_public_key,
|
211
|
-
),
|
212
|
-
)
|
213
|
-
)
|
214
|
-
state = self.state_factory.state()
|
215
|
-
node_id = state.get_node_id(public_key_bytes)
|
216
|
-
|
217
|
-
# Handle `CreateNode` here instead of calling the default method handler
|
218
|
-
# Return previously assigned `node_id` for the provided `public_key`
|
219
|
-
if node_id is not None:
|
220
|
-
state.acknowledge_ping(node_id, request.ping_interval)
|
221
|
-
return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
|
222
|
-
|
223
|
-
# No `node_id` exists for the provided `public_key`
|
224
|
-
# Handle `CreateNode` here instead of calling the default method handler
|
225
|
-
# Note: the innermost `CreateNode` method will never be called
|
226
|
-
node_id = state.create_node(request.ping_interval, public_key_bytes)
|
227
|
-
return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
|
@@ -20,7 +20,14 @@ from typing import Optional
|
|
20
20
|
from uuid import UUID
|
21
21
|
|
22
22
|
from flwr.common.constant import Status
|
23
|
-
from flwr.common.serde import
|
23
|
+
from flwr.common.serde import (
|
24
|
+
fab_to_proto,
|
25
|
+
message_from_proto,
|
26
|
+
message_from_taskins,
|
27
|
+
message_to_proto,
|
28
|
+
message_to_taskres,
|
29
|
+
user_config_to_proto,
|
30
|
+
)
|
24
31
|
from flwr.common.typing import Fab, InvalidRunStatusException
|
25
32
|
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
26
33
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
@@ -30,8 +37,12 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
30
37
|
DeleteNodeResponse,
|
31
38
|
PingRequest,
|
32
39
|
PingResponse,
|
40
|
+
PullMessagesRequest,
|
41
|
+
PullMessagesResponse,
|
33
42
|
PullTaskInsRequest,
|
34
43
|
PullTaskInsResponse,
|
44
|
+
PushMessagesRequest,
|
45
|
+
PushMessagesResponse,
|
35
46
|
PushTaskResRequest,
|
36
47
|
PushTaskResResponse,
|
37
48
|
Reconnect,
|
@@ -55,13 +66,13 @@ def create_node(
|
|
55
66
|
"""."""
|
56
67
|
# Create node
|
57
68
|
node_id = state.create_node(ping_interval=request.ping_interval)
|
58
|
-
return CreateNodeResponse(node=Node(node_id=node_id
|
69
|
+
return CreateNodeResponse(node=Node(node_id=node_id))
|
59
70
|
|
60
71
|
|
61
72
|
def delete_node(request: DeleteNodeRequest, state: LinkState) -> DeleteNodeResponse:
|
62
73
|
"""."""
|
63
74
|
# Validate node_id
|
64
|
-
if request.node.
|
75
|
+
if request.node.node_id == 0: # i.e. unset `node_id`
|
65
76
|
return DeleteNodeResponse()
|
66
77
|
|
67
78
|
# Update state
|
@@ -80,9 +91,8 @@ def ping(
|
|
80
91
|
|
81
92
|
def pull_task_ins(request: PullTaskInsRequest, state: LinkState) -> PullTaskInsResponse:
|
82
93
|
"""Pull TaskIns handler."""
|
83
|
-
# Get node_id if client node is not anonymous
|
84
94
|
node = request.node # pylint: disable=no-member
|
85
|
-
node_id:
|
95
|
+
node_id: int = node.node_id
|
86
96
|
|
87
97
|
# Retrieve TaskIns from State
|
88
98
|
task_ins_list: list[TaskIns] = state.get_task_ins(node_id=node_id, limit=1)
|
@@ -94,6 +104,26 @@ def pull_task_ins(request: PullTaskInsRequest, state: LinkState) -> PullTaskInsR
|
|
94
104
|
return response
|
95
105
|
|
96
106
|
|
107
|
+
def pull_messages(
|
108
|
+
request: PullMessagesRequest, state: LinkState
|
109
|
+
) -> PullMessagesResponse:
|
110
|
+
"""Pull Messages handler."""
|
111
|
+
# Get node_id if client node is not anonymous
|
112
|
+
node = request.node # pylint: disable=no-member
|
113
|
+
node_id: int = node.node_id
|
114
|
+
|
115
|
+
# Retrieve TaskIns from State
|
116
|
+
task_ins_list: list[TaskIns] = state.get_task_ins(node_id=node_id, limit=1)
|
117
|
+
|
118
|
+
# Convert to Messages
|
119
|
+
msg_proto = []
|
120
|
+
for task_ins in task_ins_list:
|
121
|
+
msg = message_from_taskins(task_ins)
|
122
|
+
msg_proto.append(message_to_proto(msg))
|
123
|
+
|
124
|
+
return PullMessagesResponse(messages_list=msg_proto)
|
125
|
+
|
126
|
+
|
97
127
|
def push_task_res(request: PushTaskResRequest, state: LinkState) -> PushTaskResResponse:
|
98
128
|
"""Push TaskRes handler."""
|
99
129
|
# pylint: disable=no-member
|
@@ -123,6 +153,37 @@ def push_task_res(request: PushTaskResRequest, state: LinkState) -> PushTaskResR
|
|
123
153
|
return response
|
124
154
|
|
125
155
|
|
156
|
+
def push_messages(
|
157
|
+
request: PushMessagesRequest, state: LinkState
|
158
|
+
) -> PushMessagesResponse:
|
159
|
+
"""Push Messages handler."""
|
160
|
+
# Convert Message to TaskRes
|
161
|
+
msg = message_from_proto(message_proto=request.messages_list[0])
|
162
|
+
task_res = message_to_taskres(msg)
|
163
|
+
|
164
|
+
# Abort if the run is not running
|
165
|
+
abort_msg = check_abort(
|
166
|
+
task_res.run_id,
|
167
|
+
[Status.PENDING, Status.STARTING, Status.FINISHED],
|
168
|
+
state,
|
169
|
+
)
|
170
|
+
if abort_msg:
|
171
|
+
raise InvalidRunStatusException(abort_msg)
|
172
|
+
|
173
|
+
# Set pushed_at (timestamp in seconds)
|
174
|
+
task_res.task.pushed_at = time.time()
|
175
|
+
|
176
|
+
# Store TaskRes in State
|
177
|
+
message_id: Optional[UUID] = state.store_task_res(task_res=task_res)
|
178
|
+
|
179
|
+
# Build response
|
180
|
+
response = PushMessagesResponse(
|
181
|
+
reconnect=Reconnect(reconnect=5),
|
182
|
+
results={str(message_id): 0},
|
183
|
+
)
|
184
|
+
return response
|
185
|
+
|
186
|
+
|
126
187
|
def get_run(request: GetRunRequest, state: LinkState) -> GetRunResponse:
|
127
188
|
"""Get run information."""
|
128
189
|
run = state.get_run(request.run_id)
|
@@ -17,13 +17,12 @@
|
|
17
17
|
|
18
18
|
from __future__ import annotations
|
19
19
|
|
20
|
-
import sys
|
21
20
|
from collections.abc import Awaitable
|
22
21
|
from typing import Callable, TypeVar, cast
|
23
22
|
|
24
23
|
from google.protobuf.message import Message as GrpcMessage
|
25
24
|
|
26
|
-
from flwr.common.
|
25
|
+
from flwr.common.exit import ExitCode, flwr_exit
|
27
26
|
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
28
27
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
29
28
|
CreateNodeRequest,
|
@@ -32,8 +31,12 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
32
31
|
DeleteNodeResponse,
|
33
32
|
PingRequest,
|
34
33
|
PingResponse,
|
34
|
+
PullMessagesRequest,
|
35
|
+
PullMessagesResponse,
|
35
36
|
PullTaskInsRequest,
|
36
37
|
PullTaskInsResponse,
|
38
|
+
PushMessagesRequest,
|
39
|
+
PushMessagesResponse,
|
37
40
|
PushTaskResRequest,
|
38
41
|
PushTaskResResponse,
|
39
42
|
)
|
@@ -51,7 +54,7 @@ try:
|
|
51
54
|
from starlette.responses import Response
|
52
55
|
from starlette.routing import Route
|
53
56
|
except ModuleNotFoundError:
|
54
|
-
|
57
|
+
flwr_exit(ExitCode.COMMON_MISSING_EXTRA_REST)
|
55
58
|
|
56
59
|
|
57
60
|
GrpcRequest = TypeVar("GrpcRequest", bound=GrpcMessage)
|
@@ -117,6 +120,16 @@ async def pull_task_ins(request: PullTaskInsRequest) -> PullTaskInsResponse:
|
|
117
120
|
return message_handler.pull_task_ins(request=request, state=state)
|
118
121
|
|
119
122
|
|
123
|
+
@rest_request_response(PullMessagesRequest)
|
124
|
+
async def pull_message(request: PullMessagesRequest) -> PullMessagesResponse:
|
125
|
+
"""Pull PullMessages."""
|
126
|
+
# Get state from app
|
127
|
+
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
128
|
+
|
129
|
+
# Handle message
|
130
|
+
return message_handler.pull_messages(request=request, state=state)
|
131
|
+
|
132
|
+
|
120
133
|
# Check if token is needed here
|
121
134
|
@rest_request_response(PushTaskResRequest)
|
122
135
|
async def push_task_res(request: PushTaskResRequest) -> PushTaskResResponse:
|
@@ -128,6 +141,16 @@ async def push_task_res(request: PushTaskResRequest) -> PushTaskResResponse:
|
|
128
141
|
return message_handler.push_task_res(request=request, state=state)
|
129
142
|
|
130
143
|
|
144
|
+
@rest_request_response(PushMessagesRequest)
|
145
|
+
async def push_message(request: PushMessagesRequest) -> PushMessagesResponse:
|
146
|
+
"""Pull PushMessages."""
|
147
|
+
# Get state from app
|
148
|
+
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
|
149
|
+
|
150
|
+
# Handle message
|
151
|
+
return message_handler.push_messages(request=request, state=state)
|
152
|
+
|
153
|
+
|
131
154
|
@rest_request_response(PingRequest)
|
132
155
|
async def ping(request: PingRequest) -> PingResponse:
|
133
156
|
"""Ping."""
|
@@ -165,7 +188,9 @@ routes = [
|
|
165
188
|
Route("/api/v0/fleet/create-node", create_node, methods=["POST"]),
|
166
189
|
Route("/api/v0/fleet/delete-node", delete_node, methods=["POST"]),
|
167
190
|
Route("/api/v0/fleet/pull-task-ins", pull_task_ins, methods=["POST"]),
|
191
|
+
Route("/api/v0/fleet/pull-messages", pull_message, methods=["POST"]),
|
168
192
|
Route("/api/v0/fleet/push-task-res", push_task_res, methods=["POST"]),
|
193
|
+
Route("/api/v0/fleet/push-messages", push_message, methods=["POST"]),
|
169
194
|
Route("/api/v0/fleet/ping", ping, methods=["POST"]),
|
170
195
|
Route("/api/v0/fleet/get-run", get_run, methods=["POST"]),
|
171
196
|
Route("/api/v0/fleet/get-fab", get_fab, methods=["POST"]),
|
@@ -182,8 +182,8 @@ def run_api(
|
|
182
182
|
f_stop: threading.Event,
|
183
183
|
) -> None:
|
184
184
|
"""Run the VCE."""
|
185
|
-
taskins_queue:
|
186
|
-
taskres_queue:
|
185
|
+
taskins_queue: Queue[TaskIns] = Queue()
|
186
|
+
taskres_queue: Queue[TaskRes] = Queue()
|
187
187
|
|
188
188
|
try:
|
189
189
|
|