flwr-nightly 1.15.0.dev20250114__py3-none-any.whl → 1.15.0.dev20250123__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/config_utils.py +23 -146
- flwr/cli/constant.py +27 -0
- flwr/cli/install.py +1 -1
- flwr/cli/log.py +17 -2
- flwr/cli/login/login.py +9 -1
- flwr/cli/ls.py +10 -2
- flwr/cli/run/run.py +20 -10
- flwr/cli/stop.py +9 -1
- flwr/client/app.py +23 -43
- flwr/client/clientapp/app.py +4 -6
- flwr/client/clientapp/utils.py +1 -1
- flwr/client/grpc_client/connection.py +0 -6
- flwr/client/grpc_rere_client/client_interceptor.py +19 -125
- flwr/client/grpc_rere_client/connection.py +10 -0
- flwr/client/rest_client/connection.py +12 -3
- flwr/client/supernode/app.py +14 -20
- flwr/common/auth_plugin/auth_plugin.py +1 -0
- flwr/common/config.py +152 -15
- flwr/common/constant.py +9 -8
- flwr/common/exit/__init__.py +24 -0
- flwr/common/exit/exit.py +99 -0
- flwr/common/exit/exit_code.py +93 -0
- flwr/common/exit_handlers.py +24 -10
- flwr/common/grpc.py +7 -0
- flwr/common/logger.py +1 -1
- flwr/common/serde.py +6 -4
- flwr/proto/clientappio_pb2.py +13 -3
- flwr/proto/clientappio_pb2_grpc.py +63 -12
- flwr/proto/error_pb2.py +13 -3
- flwr/proto/error_pb2_grpc.py +20 -0
- flwr/proto/exec_pb2.py +15 -5
- flwr/proto/exec_pb2_grpc.py +105 -24
- flwr/proto/fab_pb2.py +13 -3
- flwr/proto/fab_pb2_grpc.py +20 -0
- flwr/proto/fleet_pb2.py +15 -5
- flwr/proto/fleet_pb2_grpc.py +147 -36
- flwr/proto/grpcadapter_pb2.py +14 -4
- flwr/proto/grpcadapter_pb2_grpc.py +35 -4
- flwr/proto/log_pb2.py +13 -3
- flwr/proto/log_pb2_grpc.py +20 -0
- flwr/proto/message_pb2.py +15 -5
- flwr/proto/message_pb2_grpc.py +20 -0
- flwr/proto/node_pb2.py +15 -5
- flwr/proto/node_pb2.pyi +1 -4
- flwr/proto/node_pb2_grpc.py +20 -0
- flwr/proto/recordset_pb2.py +18 -8
- flwr/proto/recordset_pb2_grpc.py +20 -0
- flwr/proto/run_pb2.py +16 -6
- flwr/proto/run_pb2_grpc.py +20 -0
- flwr/proto/serverappio_pb2.py +32 -14
- flwr/proto/serverappio_pb2.pyi +56 -0
- flwr/proto/serverappio_pb2_grpc.py +261 -44
- flwr/proto/serverappio_pb2_grpc.pyi +20 -0
- flwr/proto/simulationio_pb2.py +13 -3
- flwr/proto/simulationio_pb2_grpc.py +105 -24
- flwr/proto/task_pb2.py +13 -3
- flwr/proto/task_pb2_grpc.py +20 -0
- flwr/proto/transport_pb2.py +20 -10
- flwr/proto/transport_pb2_grpc.py +35 -4
- flwr/server/app.py +40 -11
- flwr/server/compat/app_utils.py +0 -1
- flwr/server/compat/driver_client_proxy.py +1 -2
- flwr/server/driver/grpc_driver.py +5 -2
- flwr/server/driver/inmemory_driver.py +2 -1
- flwr/server/serverapp/app.py +5 -6
- flwr/server/superlink/driver/serverappio_servicer.py +110 -6
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +20 -88
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +95 -169
- flwr/server/superlink/fleet/message_handler/message_handler.py +4 -5
- flwr/server/superlink/fleet/rest_rere/rest_api.py +2 -3
- flwr/server/superlink/linkstate/in_memory_linkstate.py +14 -26
- flwr/server/superlink/linkstate/linkstate.py +5 -18
- flwr/server/superlink/linkstate/sqlite_linkstate.py +30 -70
- flwr/server/superlink/linkstate/utils.py +18 -8
- flwr/server/utils/validator.py +9 -34
- flwr/simulation/app.py +4 -6
- flwr/simulation/legacy_app.py +4 -2
- {flwr_nightly-1.15.0.dev20250114.dist-info → flwr_nightly-1.15.0.dev20250123.dist-info}/METADATA +4 -4
- {flwr_nightly-1.15.0.dev20250114.dist-info → flwr_nightly-1.15.0.dev20250123.dist-info}/RECORD +82 -78
- {flwr_nightly-1.15.0.dev20250114.dist-info → flwr_nightly-1.15.0.dev20250123.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.15.0.dev20250114.dist-info → flwr_nightly-1.15.0.dev20250123.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.15.0.dev20250114.dist-info → flwr_nightly-1.15.0.dev20250123.dist-info}/entry_points.txt +0 -0
@@ -16,14 +16,13 @@
|
|
16
16
|
|
17
17
|
|
18
18
|
import threading
|
19
|
-
import time
|
20
19
|
from logging import DEBUG, INFO
|
21
20
|
from typing import Optional
|
22
21
|
from uuid import UUID
|
23
22
|
|
24
23
|
import grpc
|
25
24
|
|
26
|
-
from flwr.common import ConfigsRecord
|
25
|
+
from flwr.common import ConfigsRecord, now
|
27
26
|
from flwr.common.constant import Status
|
28
27
|
from flwr.common.logger import log
|
29
28
|
from flwr.common.serde import (
|
@@ -31,6 +30,10 @@ from flwr.common.serde import (
|
|
31
30
|
context_to_proto,
|
32
31
|
fab_from_proto,
|
33
32
|
fab_to_proto,
|
33
|
+
message_from_proto,
|
34
|
+
message_from_taskres,
|
35
|
+
message_to_proto,
|
36
|
+
message_to_taskins,
|
34
37
|
run_status_from_proto,
|
35
38
|
run_status_to_proto,
|
36
39
|
run_to_proto,
|
@@ -57,10 +60,14 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
57
60
|
from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
|
58
61
|
GetNodesRequest,
|
59
62
|
GetNodesResponse,
|
63
|
+
PullResMessagesRequest,
|
64
|
+
PullResMessagesResponse,
|
60
65
|
PullServerAppInputsRequest,
|
61
66
|
PullServerAppInputsResponse,
|
62
67
|
PullTaskResRequest,
|
63
68
|
PullTaskResResponse,
|
69
|
+
PushInsMessagesRequest,
|
70
|
+
PushInsMessagesResponse,
|
64
71
|
PushServerAppOutputsRequest,
|
65
72
|
PushServerAppOutputsResponse,
|
66
73
|
PushTaskInsRequest,
|
@@ -102,9 +109,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
102
109
|
)
|
103
110
|
|
104
111
|
all_ids: set[int] = state.get_nodes(request.run_id)
|
105
|
-
nodes: list[Node] = [
|
106
|
-
Node(node_id=node_id, anonymous=False) for node_id in all_ids
|
107
|
-
]
|
112
|
+
nodes: list[Node] = [Node(node_id=node_id) for node_id in all_ids]
|
108
113
|
return GetNodesResponse(nodes=nodes)
|
109
114
|
|
110
115
|
def CreateRun(
|
@@ -151,7 +156,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
151
156
|
)
|
152
157
|
|
153
158
|
# Set pushed_at (timestamp in seconds)
|
154
|
-
pushed_at =
|
159
|
+
pushed_at = now().timestamp()
|
155
160
|
for task_ins in request.task_ins_list:
|
156
161
|
task_ins.task.pushed_at = pushed_at
|
157
162
|
|
@@ -184,6 +189,59 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
184
189
|
task_ids=[str(task_id) if task_id else "" for task_id in task_ids]
|
185
190
|
)
|
186
191
|
|
192
|
+
def PushMessages(
    self, request: PushInsMessagesRequest, context: grpc.ServicerContext
) -> PushInsMessagesResponse:
    """Push a set of Messages.

    Each Message is converted to a TaskIns, stamped with a common
    ``pushed_at`` timestamp, validated, and stored in the LinkState.

    Parameters
    ----------
    request : PushInsMessagesRequest
        Carries the ``run_id`` and the non-empty ``messages_list`` to store.
        Note: ``messages_list`` is consumed (emptied) by this call.
    context : grpc.ServicerContext
        gRPC context, used to abort the call if the run is not running.

    Returns
    -------
    PushInsMessagesResponse
        One ID string per pushed Message, in request order; an empty string
        is used where the LinkState returned no ID.

    Notes
    -----
    The whole batch is validated before anything is stored, so a single
    invalid Message (or one with a mismatched ``run_id``) rejects the request
    without leaving earlier Messages of the same batch persisted.
    """
    log(DEBUG, "ServerAppIoServicer.PushMessages")

    # Init state
    state: LinkState = self.state_factory.state()

    # Abort if the run is not running
    abort_if(
        request.run_id,
        [Status.PENDING, Status.STARTING, Status.FINISHED],
        state,
        context,
    )

    # Set pushed_at (timestamp in seconds)
    pushed_at = now().timestamp()

    # Validate request
    _raise_if(
        validation_error=len(request.messages_list) == 0,
        request_name="PushMessages",
        detail="`messages_list` must not be empty",
    )

    # Convert and validate every Message before storing any of them; storing
    # inside this loop would persist a partial batch if a later one is invalid
    validated_task_ins = []
    while request.messages_list:
        message_proto = request.messages_list.pop(0)
        message = message_from_proto(message_proto=message_proto)
        task_ins = message_to_taskins(message=message)
        task_ins.task.pushed_at = pushed_at
        validation_errors = validate_task_ins_or_res(task_ins)
        _raise_if(
            validation_error=bool(validation_errors),
            request_name="PushMessages",
            detail=", ".join(validation_errors),
        )
        _raise_if(
            validation_error=request.run_id != task_ins.run_id,
            request_name="PushMessages",
            detail="`task_ins` has mismatched `run_id`",
        )
        validated_task_ins.append(task_ins)

    # Store the (fully validated) batch
    message_ids: list[Optional[UUID]] = [
        state.store_task_ins(task_ins=task_ins) for task_ins in validated_task_ins
    ]

    return PushInsMessagesResponse(
        message_ids=[
            str(message_id) if message_id else "" for message_id in message_ids
        ]
    )
|
244
|
+
|
187
245
|
def PullTaskRes(
|
188
246
|
self, request: PullTaskResRequest, context: grpc.ServicerContext
|
189
247
|
) -> PullTaskResResponse:
|
@@ -223,6 +281,52 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
223
281
|
|
224
282
|
return PullTaskResResponse(task_res_list=task_res_list)
|
225
283
|
|
284
|
+
def PullMessages(
    self, request: PullResMessagesRequest, context: grpc.ServicerContext
) -> PullResMessagesResponse:
    """Pull the reply Messages for a set of previously pushed Messages.

    Parameters
    ----------
    request : PullResMessagesRequest
        Carries the ``run_id`` and the string IDs of the instruction Messages
        whose replies should be returned.
    context : grpc.ServicerContext
        gRPC context, used to abort the call if the run is not running.

    Returns
    -------
    PullResMessagesResponse
        The reply Messages found in the LinkState, converted to protobuf.

    Raises
    ------
    ValueError
        If an entry of ``request.message_ids`` is not a valid UUID string
        (raised by ``uuid.UUID``). A reply whose ``run_id`` differs from the
        request's ``run_id`` is rejected via ``_raise_if``.

    Notes
    -----
    After all replies have been converted, the TaskIns/TaskRes pairs for
    which a TaskRes was found are deleted from the LinkState.
    """
    log(DEBUG, "ServerAppIoServicer.PullMessages")

    # Init state
    state: LinkState = self.state_factory.state()

    # Abort if the run is not running
    abort_if(
        request.run_id,
        [Status.PENDING, Status.STARTING, Status.FINISHED],
        state,
        context,
    )

    # Convert each message_id str to UUID
    message_ids: set[UUID] = {
        UUID(message_id) for message_id in request.message_ids
    }

    # Read from state
    task_res_list: list[TaskRes] = state.get_task_res(task_ids=message_ids)

    # Collect the IDs of the TaskIns answered by these TaskRes *before* the
    # conversion loop below empties `task_res_list`; computed afterwards the
    # set would always be empty and nothing would ever be deleted
    task_ins_ids_to_delete = {
        UUID(task_res.task.ancestry[0]) for task_res in task_res_list
    }

    # Convert to Messages (popping presumably lets each TaskRes be released
    # as soon as it has been converted)
    messages_list = []
    while task_res_list:
        task_res = task_res_list.pop(0)
        _raise_if(
            validation_error=request.run_id != task_res.run_id,
            request_name="PullMessages",
            detail="`task_res` has mismatched `run_id`",
        )
        message = message_from_taskres(taskres=task_res)
        messages_list.append(message_to_proto(message))

    # Delete the TaskIns/TaskRes pairs for which a TaskRes was found
    state.delete_tasks(task_ins_ids=task_ins_ids_to_delete)

    return PullResMessagesResponse(messages_list=messages_list)
|
329
|
+
|
226
330
|
def GetRun(
|
227
331
|
self, request: GetRunRequest, context: grpc.ServicerContext
|
228
332
|
) -> GetRunResponse:
|
@@ -15,7 +15,7 @@
|
|
15
15
|
"""Fleet API gRPC adapter servicer."""
|
16
16
|
|
17
17
|
|
18
|
-
from logging import DEBUG
|
18
|
+
from logging import DEBUG
|
19
19
|
from typing import Callable, TypeVar
|
20
20
|
|
21
21
|
import grpc
|
@@ -31,35 +31,30 @@ from flwr.common.constant import (
|
|
31
31
|
from flwr.common.logger import log
|
32
32
|
from flwr.common.version import package_name, package_version
|
33
33
|
from flwr.proto import grpcadapter_pb2_grpc # pylint: disable=E0611
|
34
|
-
from flwr.proto.fab_pb2 import GetFabRequest
|
34
|
+
from flwr.proto.fab_pb2 import GetFabRequest # pylint: disable=E0611
|
35
35
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
36
36
|
CreateNodeRequest,
|
37
|
-
CreateNodeResponse,
|
38
37
|
DeleteNodeRequest,
|
39
|
-
DeleteNodeResponse,
|
40
38
|
PingRequest,
|
41
|
-
|
42
|
-
|
43
|
-
PullTaskInsResponse,
|
44
|
-
PushTaskResRequest,
|
45
|
-
PushTaskResResponse,
|
39
|
+
PullMessagesRequest,
|
40
|
+
PushMessagesRequest,
|
46
41
|
)
|
47
42
|
from flwr.proto.grpcadapter_pb2 import MessageContainer # pylint: disable=E0611
|
48
|
-
from flwr.proto.run_pb2 import GetRunRequest
|
49
|
-
|
50
|
-
from
|
51
|
-
from flwr.server.superlink.linkstate import LinkStateFactory
|
43
|
+
from flwr.proto.run_pb2 import GetRunRequest # pylint: disable=E0611
|
44
|
+
|
45
|
+
from ..grpc_rere.fleet_servicer import FleetServicer
|
52
46
|
|
53
47
|
T = TypeVar("T", bound=GrpcMessage)
|
54
48
|
|
55
49
|
|
56
50
|
def _handle(
|
57
51
|
msg_container: MessageContainer,
|
52
|
+
context: grpc.ServicerContext,
|
58
53
|
request_type: type[T],
|
59
|
-
handler: Callable[[T], GrpcMessage],
|
54
|
+
handler: Callable[[T, grpc.ServicerContext], GrpcMessage],
|
60
55
|
) -> MessageContainer:
|
61
56
|
req = request_type.FromString(msg_container.grpc_message_content)
|
62
|
-
res = handler(req)
|
57
|
+
res = handler(req, context)
|
63
58
|
res_cls = res.__class__
|
64
59
|
return MessageContainer(
|
65
60
|
metadata={
|
@@ -74,89 +69,26 @@ def _handle(
|
|
74
69
|
)
|
75
70
|
|
76
71
|
|
77
|
-
class GrpcAdapterServicer(grpcadapter_pb2_grpc.GrpcAdapterServicer):
|
72
|
+
class GrpcAdapterServicer(grpcadapter_pb2_grpc.GrpcAdapterServicer, FleetServicer):
|
78
73
|
"""Fleet API via GrpcAdapter servicer."""
|
79
74
|
|
80
|
-
def __init__(
|
81
|
-
self, state_factory: LinkStateFactory, ffs_factory: FfsFactory
|
82
|
-
) -> None:
|
83
|
-
self.state_factory = state_factory
|
84
|
-
self.ffs_factory = ffs_factory
|
85
|
-
|
86
75
|
def SendReceive( # pylint: disable=too-many-return-statements
|
87
76
|
self, request: MessageContainer, context: grpc.ServicerContext
|
88
77
|
) -> MessageContainer:
|
89
78
|
"""."""
|
90
79
|
log(DEBUG, "GrpcAdapterServicer.SendReceive")
|
91
80
|
if request.grpc_message_name == CreateNodeRequest.__qualname__:
|
92
|
-
return _handle(request, CreateNodeRequest, self.
|
81
|
+
return _handle(request, context, CreateNodeRequest, self.CreateNode)
|
93
82
|
if request.grpc_message_name == DeleteNodeRequest.__qualname__:
|
94
|
-
return _handle(request, DeleteNodeRequest, self.
|
83
|
+
return _handle(request, context, DeleteNodeRequest, self.DeleteNode)
|
95
84
|
if request.grpc_message_name == PingRequest.__qualname__:
|
96
|
-
return _handle(request, PingRequest, self.
|
97
|
-
if request.grpc_message_name == PullTaskInsRequest.__qualname__:
|
98
|
-
return _handle(request, PullTaskInsRequest, self._pull_task_ins)
|
99
|
-
if request.grpc_message_name == PushTaskResRequest.__qualname__:
|
100
|
-
return _handle(request, PushTaskResRequest, self._push_task_res)
|
85
|
+
return _handle(request, context, PingRequest, self.Ping)
|
101
86
|
if request.grpc_message_name == GetRunRequest.__qualname__:
|
102
|
-
return _handle(request, GetRunRequest, self.
|
87
|
+
return _handle(request, context, GetRunRequest, self.GetRun)
|
103
88
|
if request.grpc_message_name == GetFabRequest.__qualname__:
|
104
|
-
return _handle(request, GetFabRequest, self.
|
89
|
+
return _handle(request, context, GetFabRequest, self.GetFab)
|
90
|
+
if request.grpc_message_name == PullMessagesRequest.__qualname__:
|
91
|
+
return _handle(request, context, PullMessagesRequest, self.PullMessages)
|
92
|
+
if request.grpc_message_name == PushMessagesRequest.__qualname__:
|
93
|
+
return _handle(request, context, PushMessagesRequest, self.PushMessages)
|
105
94
|
raise ValueError(f"Invalid grpc_message_name: {request.grpc_message_name}")
|
106
|
-
|
107
|
-
def _create_node(self, request: CreateNodeRequest) -> CreateNodeResponse:
|
108
|
-
"""."""
|
109
|
-
log(INFO, "GrpcAdapter.CreateNode")
|
110
|
-
return message_handler.create_node(
|
111
|
-
request=request,
|
112
|
-
state=self.state_factory.state(),
|
113
|
-
)
|
114
|
-
|
115
|
-
def _delete_node(self, request: DeleteNodeRequest) -> DeleteNodeResponse:
|
116
|
-
"""."""
|
117
|
-
log(INFO, "GrpcAdapter.DeleteNode")
|
118
|
-
return message_handler.delete_node(
|
119
|
-
request=request,
|
120
|
-
state=self.state_factory.state(),
|
121
|
-
)
|
122
|
-
|
123
|
-
def _ping(self, request: PingRequest) -> PingResponse:
|
124
|
-
"""."""
|
125
|
-
log(DEBUG, "GrpcAdapter.Ping")
|
126
|
-
return message_handler.ping(
|
127
|
-
request=request,
|
128
|
-
state=self.state_factory.state(),
|
129
|
-
)
|
130
|
-
|
131
|
-
def _pull_task_ins(self, request: PullTaskInsRequest) -> PullTaskInsResponse:
|
132
|
-
"""Pull TaskIns."""
|
133
|
-
log(INFO, "GrpcAdapter.PullTaskIns")
|
134
|
-
return message_handler.pull_task_ins(
|
135
|
-
request=request,
|
136
|
-
state=self.state_factory.state(),
|
137
|
-
)
|
138
|
-
|
139
|
-
def _push_task_res(self, request: PushTaskResRequest) -> PushTaskResResponse:
|
140
|
-
"""Push TaskRes."""
|
141
|
-
log(INFO, "GrpcAdapter.PushTaskRes")
|
142
|
-
return message_handler.push_task_res(
|
143
|
-
request=request,
|
144
|
-
state=self.state_factory.state(),
|
145
|
-
)
|
146
|
-
|
147
|
-
def _get_run(self, request: GetRunRequest) -> GetRunResponse:
|
148
|
-
"""Get run information."""
|
149
|
-
log(INFO, "GrpcAdapter.GetRun")
|
150
|
-
return message_handler.get_run(
|
151
|
-
request=request,
|
152
|
-
state=self.state_factory.state(),
|
153
|
-
)
|
154
|
-
|
155
|
-
def _get_fab(self, request: GetFabRequest) -> GetFabResponse:
|
156
|
-
"""Get FAB."""
|
157
|
-
log(INFO, "GrpcAdapter.GetFab")
|
158
|
-
return message_handler.get_fab(
|
159
|
-
request=request,
|
160
|
-
ffs=self.ffs_factory.ffs(),
|
161
|
-
state=self.state_factory.state(),
|
162
|
-
)
|
@@ -15,91 +15,54 @@
|
|
15
15
|
"""Flower server interceptor."""
|
16
16
|
|
17
17
|
|
18
|
-
import
|
19
|
-
from
|
20
|
-
from logging import INFO, WARNING
|
21
|
-
from typing import Any, Callable, Optional, Union
|
18
|
+
import datetime
|
19
|
+
from typing import Any, Callable, Optional, cast
|
22
20
|
|
23
21
|
import grpc
|
24
|
-
from
|
25
|
-
|
26
|
-
from flwr.common
|
22
|
+
from google.protobuf.message import Message as GrpcMessage
|
23
|
+
|
24
|
+
from flwr.common import now
|
25
|
+
from flwr.common.constant import (
|
26
|
+
PUBLIC_KEY_HEADER,
|
27
|
+
SIGNATURE_HEADER,
|
28
|
+
TIMESTAMP_HEADER,
|
29
|
+
TIMESTAMP_TOLERANCE,
|
30
|
+
)
|
27
31
|
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
28
|
-
bytes_to_private_key,
|
29
32
|
bytes_to_public_key,
|
30
|
-
|
31
|
-
verify_hmac,
|
33
|
+
verify_signature,
|
32
34
|
)
|
33
|
-
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
34
35
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
35
36
|
CreateNodeRequest,
|
36
37
|
CreateNodeResponse,
|
37
|
-
DeleteNodeRequest,
|
38
|
-
DeleteNodeResponse,
|
39
|
-
PingRequest,
|
40
|
-
PingResponse,
|
41
|
-
PullTaskInsRequest,
|
42
|
-
PullTaskInsResponse,
|
43
|
-
PushTaskResRequest,
|
44
|
-
PushTaskResResponse,
|
45
38
|
)
|
46
|
-
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
47
|
-
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
48
39
|
from flwr.server.superlink.linkstate import LinkStateFactory
|
49
40
|
|
50
|
-
_PUBLIC_KEY_HEADER = "public-key"
|
51
|
-
_AUTH_TOKEN_HEADER = "auth-token"
|
52
|
-
|
53
|
-
Request = Union[
|
54
|
-
CreateNodeRequest,
|
55
|
-
DeleteNodeRequest,
|
56
|
-
PullTaskInsRequest,
|
57
|
-
PushTaskResRequest,
|
58
|
-
GetRunRequest,
|
59
|
-
PingRequest,
|
60
|
-
GetFabRequest,
|
61
|
-
]
|
62
|
-
|
63
|
-
Response = Union[
|
64
|
-
CreateNodeResponse,
|
65
|
-
DeleteNodeResponse,
|
66
|
-
PullTaskInsResponse,
|
67
|
-
PushTaskResResponse,
|
68
|
-
GetRunResponse,
|
69
|
-
PingResponse,
|
70
|
-
GetFabResponse,
|
71
|
-
]
|
72
|
-
|
73
41
|
|
74
|
-
def
|
75
|
-
|
76
|
-
)
|
77
|
-
|
78
|
-
if isinstance(value, str):
|
79
|
-
return value.encode()
|
42
|
+
def _unary_unary_rpc_terminator(message: str) -> grpc.RpcMethodHandler:
|
43
|
+
def terminate(_request: GrpcMessage, context: grpc.ServicerContext) -> GrpcMessage:
|
44
|
+
context.abort(grpc.StatusCode.UNAUTHENTICATED, message)
|
45
|
+
raise RuntimeError("Should not reach this point") # Make mypy happy
|
80
46
|
|
81
|
-
return
|
47
|
+
return grpc.unary_unary_rpc_method_handler(terminate)
|
82
48
|
|
83
49
|
|
84
50
|
class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
85
|
-
"""Server interceptor for node authentication.
|
86
|
-
|
87
|
-
|
51
|
+
"""Server interceptor for node authentication.
|
52
|
+
|
53
|
+
Parameters
|
54
|
+
----------
|
55
|
+
state_factory : LinkStateFactory
|
56
|
+
A factory for creating new instances of LinkState.
|
57
|
+
auto_auth : bool (default: False)
|
58
|
+
If True, nodes are authenticated without requiring their public keys to be
|
59
|
+
pre-stored in the LinkState. If False, only nodes with pre-stored public keys
|
60
|
+
can be authenticated.
|
61
|
+
"""
|
62
|
+
|
63
|
+
def __init__(self, state_factory: LinkStateFactory, auto_auth: bool = False):
|
88
64
|
self.state_factory = state_factory
|
89
|
-
|
90
|
-
|
91
|
-
self.node_public_keys = state.get_node_public_keys()
|
92
|
-
if len(self.node_public_keys) == 0:
|
93
|
-
log(WARNING, "Authentication enabled, but no known public keys configured")
|
94
|
-
|
95
|
-
private_key = state.get_server_private_key()
|
96
|
-
public_key = state.get_server_public_key()
|
97
|
-
|
98
|
-
if private_key is None or public_key is None:
|
99
|
-
raise ValueError("Error loading authentication keys")
|
100
|
-
|
101
|
-
self.server_private_key = bytes_to_private_key(private_key)
|
102
|
-
self.encoded_server_public_key = base64.urlsafe_b64encode(public_key)
|
65
|
+
self.auto_auth = auto_auth
|
103
66
|
|
104
67
|
def intercept_service(
|
105
68
|
self,
|
@@ -112,117 +75,80 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
112
75
|
metadata sent by the node. Continue RPC call if node is authenticated, else,
|
113
76
|
terminate RPC call by setting context to abort.
|
114
77
|
"""
|
78
|
+
state = self.state_factory.state()
|
79
|
+
metadata_dict = dict(handler_call_details.invocation_metadata)
|
80
|
+
|
81
|
+
# Retrieve info from the metadata
|
82
|
+
try:
|
83
|
+
node_pk_bytes = cast(bytes, metadata_dict[PUBLIC_KEY_HEADER])
|
84
|
+
timestamp_iso = cast(str, metadata_dict[TIMESTAMP_HEADER])
|
85
|
+
signature = cast(bytes, metadata_dict[SIGNATURE_HEADER])
|
86
|
+
except KeyError:
|
87
|
+
return _unary_unary_rpc_terminator("Missing authentication metadata")
|
88
|
+
|
89
|
+
if not self.auto_auth:
|
90
|
+
# Abort the RPC call if the node public key is not found
|
91
|
+
if node_pk_bytes not in state.get_node_public_keys():
|
92
|
+
return _unary_unary_rpc_terminator("Public key not recognized")
|
93
|
+
|
94
|
+
# Verify the signature
|
95
|
+
node_pk = bytes_to_public_key(node_pk_bytes)
|
96
|
+
if not verify_signature(node_pk, timestamp_iso.encode("ascii"), signature):
|
97
|
+
return _unary_unary_rpc_terminator("Invalid signature")
|
98
|
+
|
99
|
+
# Verify the timestamp
|
100
|
+
current = now()
|
101
|
+
time_diff = current - datetime.datetime.fromisoformat(timestamp_iso)
|
102
|
+
# Abort the RPC call if the timestamp is too old or in the future
|
103
|
+
if not 0 < time_diff.total_seconds() < TIMESTAMP_TOLERANCE:
|
104
|
+
return _unary_unary_rpc_terminator("Invalid timestamp")
|
105
|
+
|
106
|
+
# Continue the RPC call
|
107
|
+
expected_node_id = state.get_node_id(node_pk_bytes)
|
108
|
+
if not handler_call_details.method.endswith("CreateNode"):
|
109
|
+
if expected_node_id is None:
|
110
|
+
return _unary_unary_rpc_terminator("Invalid node ID")
|
115
111
|
# One of the method handlers in
|
116
112
|
# `flwr.server.superlink.fleet.grpc_rere.fleet_server.FleetServicer`
|
117
113
|
method_handler: grpc.RpcMethodHandler = continuation(handler_call_details)
|
118
|
-
return self.
|
114
|
+
return self._wrap_method_handler(
|
115
|
+
method_handler, expected_node_id, node_pk_bytes
|
116
|
+
)
|
119
117
|
|
120
|
-
def
|
121
|
-
self,
|
118
|
+
def _wrap_method_handler(
|
119
|
+
self,
|
120
|
+
method_handler: grpc.RpcMethodHandler,
|
121
|
+
expected_node_id: Optional[int],
|
122
|
+
node_public_key: bytes,
|
122
123
|
) -> grpc.RpcMethodHandler:
|
123
124
|
def _generic_method_handler(
|
124
|
-
request:
|
125
|
+
request: GrpcMessage,
|
125
126
|
context: grpc.ServicerContext,
|
126
|
-
) ->
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
_get_value_from_tuples(
|
149
|
-
_AUTH_TOKEN_HEADER, context.invocation_metadata()
|
150
|
-
)
|
151
|
-
)
|
152
|
-
public_key = bytes_to_public_key(node_public_key_bytes)
|
153
|
-
|
154
|
-
if not self._verify_hmac(public_key, request, hmac_value):
|
155
|
-
context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
|
156
|
-
|
157
|
-
# Verify node_id
|
158
|
-
node_id = self.state_factory.state().get_node_id(node_public_key_bytes)
|
159
|
-
|
160
|
-
if not self._verify_node_id(node_id, request):
|
161
|
-
context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
|
162
|
-
|
163
|
-
return method_handler.unary_unary(request, context) # type: ignore
|
127
|
+
) -> GrpcMessage:
|
128
|
+
# Verify the node ID
|
129
|
+
if not isinstance(request, CreateNodeRequest):
|
130
|
+
try:
|
131
|
+
if request.node.node_id != expected_node_id: # type: ignore
|
132
|
+
raise ValueError
|
133
|
+
except (AttributeError, ValueError):
|
134
|
+
context.abort(grpc.StatusCode.UNAUTHENTICATED, "Invalid node ID")
|
135
|
+
|
136
|
+
response: GrpcMessage = method_handler.unary_unary(request, context)
|
137
|
+
|
138
|
+
# Set the public key after a successful CreateNode request
|
139
|
+
if isinstance(response, CreateNodeResponse):
|
140
|
+
state = self.state_factory.state()
|
141
|
+
try:
|
142
|
+
state.set_node_public_key(response.node.node_id, node_public_key)
|
143
|
+
except ValueError as e:
|
144
|
+
# Remove newly created node if setting the public key fails
|
145
|
+
state.delete_node(response.node.node_id)
|
146
|
+
context.abort(grpc.StatusCode.UNAUTHENTICATED, str(e))
|
147
|
+
|
148
|
+
return response
|
164
149
|
|
165
150
|
return grpc.unary_unary_rpc_method_handler(
|
166
151
|
_generic_method_handler,
|
167
152
|
request_deserializer=method_handler.request_deserializer,
|
168
153
|
response_serializer=method_handler.response_serializer,
|
169
154
|
)
|
170
|
-
|
171
|
-
def _verify_node_id(
|
172
|
-
self,
|
173
|
-
node_id: Optional[int],
|
174
|
-
request: Union[
|
175
|
-
DeleteNodeRequest,
|
176
|
-
PullTaskInsRequest,
|
177
|
-
PushTaskResRequest,
|
178
|
-
GetRunRequest,
|
179
|
-
PingRequest,
|
180
|
-
GetFabRequest,
|
181
|
-
],
|
182
|
-
) -> bool:
|
183
|
-
if node_id is None:
|
184
|
-
return False
|
185
|
-
if isinstance(request, PushTaskResRequest):
|
186
|
-
if len(request.task_res_list) == 0:
|
187
|
-
return False
|
188
|
-
return request.task_res_list[0].task.producer.node_id == node_id
|
189
|
-
if isinstance(request, GetRunRequest):
|
190
|
-
return node_id in self.state_factory.state().get_nodes(request.run_id)
|
191
|
-
return request.node.node_id == node_id
|
192
|
-
|
193
|
-
def _verify_hmac(
|
194
|
-
self, public_key: ec.EllipticCurvePublicKey, request: Request, hmac_value: bytes
|
195
|
-
) -> bool:
|
196
|
-
shared_secret = generate_shared_key(self.server_private_key, public_key)
|
197
|
-
message_bytes = request.SerializeToString(deterministic=True)
|
198
|
-
return verify_hmac(shared_secret, message_bytes, hmac_value)
|
199
|
-
|
200
|
-
def _create_authenticated_node(
|
201
|
-
self,
|
202
|
-
public_key_bytes: bytes,
|
203
|
-
request: CreateNodeRequest,
|
204
|
-
context: grpc.ServicerContext,
|
205
|
-
) -> CreateNodeResponse:
|
206
|
-
context.send_initial_metadata(
|
207
|
-
(
|
208
|
-
(
|
209
|
-
_PUBLIC_KEY_HEADER,
|
210
|
-
self.encoded_server_public_key,
|
211
|
-
),
|
212
|
-
)
|
213
|
-
)
|
214
|
-
state = self.state_factory.state()
|
215
|
-
node_id = state.get_node_id(public_key_bytes)
|
216
|
-
|
217
|
-
# Handle `CreateNode` here instead of calling the default method handler
|
218
|
-
# Return previously assigned `node_id` for the provided `public_key`
|
219
|
-
if node_id is not None:
|
220
|
-
state.acknowledge_ping(node_id, request.ping_interval)
|
221
|
-
return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
|
222
|
-
|
223
|
-
# No `node_id` exists for the provided `public_key`
|
224
|
-
# Handle `CreateNode` here instead of calling the default method handler
|
225
|
-
# Note: the innermost `CreateNode` method will never be called
|
226
|
-
node_id = state.create_node(request.ping_interval)
|
227
|
-
state.set_node_public_key(node_id, public_key_bytes)
|
228
|
-
return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
|
@@ -66,13 +66,13 @@ def create_node(
|
|
66
66
|
"""."""
|
67
67
|
# Create node
|
68
68
|
node_id = state.create_node(ping_interval=request.ping_interval)
|
69
|
-
return CreateNodeResponse(node=Node(node_id=node_id
|
69
|
+
return CreateNodeResponse(node=Node(node_id=node_id))
|
70
70
|
|
71
71
|
|
72
72
|
def delete_node(request: DeleteNodeRequest, state: LinkState) -> DeleteNodeResponse:
|
73
73
|
"""."""
|
74
74
|
# Validate node_id
|
75
|
-
if request.node.
|
75
|
+
if request.node.node_id == 0: # i.e. unset `node_id`
|
76
76
|
return DeleteNodeResponse()
|
77
77
|
|
78
78
|
# Update state
|
@@ -91,9 +91,8 @@ def ping(
|
|
91
91
|
|
92
92
|
def pull_task_ins(request: PullTaskInsRequest, state: LinkState) -> PullTaskInsResponse:
|
93
93
|
"""Pull TaskIns handler."""
|
94
|
-
# Get node_id if client node is not anonymous
|
95
94
|
node = request.node # pylint: disable=no-member
|
96
|
-
node_id:
|
95
|
+
node_id: int = node.node_id
|
97
96
|
|
98
97
|
# Retrieve TaskIns from State
|
99
98
|
task_ins_list: list[TaskIns] = state.get_task_ins(node_id=node_id, limit=1)
|
@@ -111,7 +110,7 @@ def pull_messages(
|
|
111
110
|
"""Pull Messages handler."""
|
112
111
|
# Get node_id if client node is not anonymous
|
113
112
|
node = request.node # pylint: disable=no-member
|
114
|
-
node_id:
|
113
|
+
node_id: int = node.node_id
|
115
114
|
|
116
115
|
# Retrieve TaskIns from State
|
117
116
|
task_ins_list: list[TaskIns] = state.get_task_ins(node_id=node_id, limit=1)
|