flwr-nightly 1.15.0.dev20250114__py3-none-any.whl → 1.15.0.dev20250123__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. flwr/cli/config_utils.py +23 -146
  2. flwr/cli/constant.py +27 -0
  3. flwr/cli/install.py +1 -1
  4. flwr/cli/log.py +17 -2
  5. flwr/cli/login/login.py +9 -1
  6. flwr/cli/ls.py +10 -2
  7. flwr/cli/run/run.py +20 -10
  8. flwr/cli/stop.py +9 -1
  9. flwr/client/app.py +23 -43
  10. flwr/client/clientapp/app.py +4 -6
  11. flwr/client/clientapp/utils.py +1 -1
  12. flwr/client/grpc_client/connection.py +0 -6
  13. flwr/client/grpc_rere_client/client_interceptor.py +19 -125
  14. flwr/client/grpc_rere_client/connection.py +10 -0
  15. flwr/client/rest_client/connection.py +12 -3
  16. flwr/client/supernode/app.py +14 -20
  17. flwr/common/auth_plugin/auth_plugin.py +1 -0
  18. flwr/common/config.py +152 -15
  19. flwr/common/constant.py +9 -8
  20. flwr/common/exit/__init__.py +24 -0
  21. flwr/common/exit/exit.py +99 -0
  22. flwr/common/exit/exit_code.py +93 -0
  23. flwr/common/exit_handlers.py +24 -10
  24. flwr/common/grpc.py +7 -0
  25. flwr/common/logger.py +1 -1
  26. flwr/common/serde.py +6 -4
  27. flwr/proto/clientappio_pb2.py +13 -3
  28. flwr/proto/clientappio_pb2_grpc.py +63 -12
  29. flwr/proto/error_pb2.py +13 -3
  30. flwr/proto/error_pb2_grpc.py +20 -0
  31. flwr/proto/exec_pb2.py +15 -5
  32. flwr/proto/exec_pb2_grpc.py +105 -24
  33. flwr/proto/fab_pb2.py +13 -3
  34. flwr/proto/fab_pb2_grpc.py +20 -0
  35. flwr/proto/fleet_pb2.py +15 -5
  36. flwr/proto/fleet_pb2_grpc.py +147 -36
  37. flwr/proto/grpcadapter_pb2.py +14 -4
  38. flwr/proto/grpcadapter_pb2_grpc.py +35 -4
  39. flwr/proto/log_pb2.py +13 -3
  40. flwr/proto/log_pb2_grpc.py +20 -0
  41. flwr/proto/message_pb2.py +15 -5
  42. flwr/proto/message_pb2_grpc.py +20 -0
  43. flwr/proto/node_pb2.py +15 -5
  44. flwr/proto/node_pb2.pyi +1 -4
  45. flwr/proto/node_pb2_grpc.py +20 -0
  46. flwr/proto/recordset_pb2.py +18 -8
  47. flwr/proto/recordset_pb2_grpc.py +20 -0
  48. flwr/proto/run_pb2.py +16 -6
  49. flwr/proto/run_pb2_grpc.py +20 -0
  50. flwr/proto/serverappio_pb2.py +32 -14
  51. flwr/proto/serverappio_pb2.pyi +56 -0
  52. flwr/proto/serverappio_pb2_grpc.py +261 -44
  53. flwr/proto/serverappio_pb2_grpc.pyi +20 -0
  54. flwr/proto/simulationio_pb2.py +13 -3
  55. flwr/proto/simulationio_pb2_grpc.py +105 -24
  56. flwr/proto/task_pb2.py +13 -3
  57. flwr/proto/task_pb2_grpc.py +20 -0
  58. flwr/proto/transport_pb2.py +20 -10
  59. flwr/proto/transport_pb2_grpc.py +35 -4
  60. flwr/server/app.py +40 -11
  61. flwr/server/compat/app_utils.py +0 -1
  62. flwr/server/compat/driver_client_proxy.py +1 -2
  63. flwr/server/driver/grpc_driver.py +5 -2
  64. flwr/server/driver/inmemory_driver.py +2 -1
  65. flwr/server/serverapp/app.py +5 -6
  66. flwr/server/superlink/driver/serverappio_servicer.py +110 -6
  67. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +20 -88
  68. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +95 -169
  69. flwr/server/superlink/fleet/message_handler/message_handler.py +4 -5
  70. flwr/server/superlink/fleet/rest_rere/rest_api.py +2 -3
  71. flwr/server/superlink/linkstate/in_memory_linkstate.py +14 -26
  72. flwr/server/superlink/linkstate/linkstate.py +5 -18
  73. flwr/server/superlink/linkstate/sqlite_linkstate.py +30 -70
  74. flwr/server/superlink/linkstate/utils.py +18 -8
  75. flwr/server/utils/validator.py +9 -34
  76. flwr/simulation/app.py +4 -6
  77. flwr/simulation/legacy_app.py +4 -2
  78. {flwr_nightly-1.15.0.dev20250114.dist-info → flwr_nightly-1.15.0.dev20250123.dist-info}/METADATA +4 -4
  79. {flwr_nightly-1.15.0.dev20250114.dist-info → flwr_nightly-1.15.0.dev20250123.dist-info}/RECORD +82 -78
  80. {flwr_nightly-1.15.0.dev20250114.dist-info → flwr_nightly-1.15.0.dev20250123.dist-info}/LICENSE +0 -0
  81. {flwr_nightly-1.15.0.dev20250114.dist-info → flwr_nightly-1.15.0.dev20250123.dist-info}/WHEEL +0 -0
  82. {flwr_nightly-1.15.0.dev20250114.dist-info → flwr_nightly-1.15.0.dev20250123.dist-info}/entry_points.txt +0 -0
@@ -16,14 +16,13 @@
16
16
 
17
17
 
18
18
  import threading
19
- import time
20
19
  from logging import DEBUG, INFO
21
20
  from typing import Optional
22
21
  from uuid import UUID
23
22
 
24
23
  import grpc
25
24
 
26
- from flwr.common import ConfigsRecord
25
+ from flwr.common import ConfigsRecord, now
27
26
  from flwr.common.constant import Status
28
27
  from flwr.common.logger import log
29
28
  from flwr.common.serde import (
@@ -31,6 +30,10 @@ from flwr.common.serde import (
31
30
  context_to_proto,
32
31
  fab_from_proto,
33
32
  fab_to_proto,
33
+ message_from_proto,
34
+ message_from_taskres,
35
+ message_to_proto,
36
+ message_to_taskins,
34
37
  run_status_from_proto,
35
38
  run_status_to_proto,
36
39
  run_to_proto,
@@ -57,10 +60,14 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
57
60
  from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
58
61
  GetNodesRequest,
59
62
  GetNodesResponse,
63
+ PullResMessagesRequest,
64
+ PullResMessagesResponse,
60
65
  PullServerAppInputsRequest,
61
66
  PullServerAppInputsResponse,
62
67
  PullTaskResRequest,
63
68
  PullTaskResResponse,
69
+ PushInsMessagesRequest,
70
+ PushInsMessagesResponse,
64
71
  PushServerAppOutputsRequest,
65
72
  PushServerAppOutputsResponse,
66
73
  PushTaskInsRequest,
@@ -102,9 +109,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
102
109
  )
103
110
 
104
111
  all_ids: set[int] = state.get_nodes(request.run_id)
105
- nodes: list[Node] = [
106
- Node(node_id=node_id, anonymous=False) for node_id in all_ids
107
- ]
112
+ nodes: list[Node] = [Node(node_id=node_id) for node_id in all_ids]
108
113
  return GetNodesResponse(nodes=nodes)
109
114
 
110
115
  def CreateRun(
@@ -151,7 +156,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
151
156
  )
152
157
 
153
158
  # Set pushed_at (timestamp in seconds)
154
- pushed_at = time.time()
159
+ pushed_at = now().timestamp()
155
160
  for task_ins in request.task_ins_list:
156
161
  task_ins.task.pushed_at = pushed_at
157
162
 
@@ -184,6 +189,59 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
184
189
  task_ids=[str(task_id) if task_id else "" for task_id in task_ids]
185
190
  )
186
191
 
192
+ def PushMessages(
193
+ self, request: PushInsMessagesRequest, context: grpc.ServicerContext
194
+ ) -> PushInsMessagesResponse:
195
+ """Push a set of Messages."""
196
+ log(DEBUG, "ServerAppIoServicer.PushMessages")
197
+
198
+ # Init state
199
+ state: LinkState = self.state_factory.state()
200
+
201
+ # Abort if the run is not running
202
+ abort_if(
203
+ request.run_id,
204
+ [Status.PENDING, Status.STARTING, Status.FINISHED],
205
+ state,
206
+ context,
207
+ )
208
+
209
+ # Set pushed_at (timestamp in seconds)
210
+ pushed_at = now().timestamp()
211
+
212
+ # Validate request and insert in State
213
+ _raise_if(
214
+ validation_error=len(request.messages_list) == 0,
215
+ request_name="PushMessages",
216
+ detail="`messages_list` must not be empty",
217
+ )
218
+ message_ids: list[Optional[UUID]] = []
219
+ while request.messages_list:
220
+ message_proto = request.messages_list.pop(0)
221
+ message = message_from_proto(message_proto=message_proto)
222
+ task_ins = message_to_taskins(message=message)
223
+ task_ins.task.pushed_at = pushed_at
224
+ validation_errors = validate_task_ins_or_res(task_ins)
225
+ _raise_if(
226
+ validation_error=bool(validation_errors),
227
+ request_name="PushMessages",
228
+ detail=", ".join(validation_errors),
229
+ )
230
+ _raise_if(
231
+ validation_error=request.run_id != task_ins.run_id,
232
+ request_name="PushMessages",
233
+ detail="`task_ins` has mismatched `run_id`",
234
+ )
235
+ # Store
236
+ message_id: Optional[UUID] = state.store_task_ins(task_ins=task_ins)
237
+ message_ids.append(message_id)
238
+
239
+ return PushInsMessagesResponse(
240
+ message_ids=[
241
+ str(message_id) if message_id else "" for message_id in message_ids
242
+ ]
243
+ )
244
+
187
245
  def PullTaskRes(
188
246
  self, request: PullTaskResRequest, context: grpc.ServicerContext
189
247
  ) -> PullTaskResResponse:
@@ -223,6 +281,52 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
223
281
 
224
282
  return PullTaskResResponse(task_res_list=task_res_list)
225
283
 
284
+ def PullMessages(
285
+ self, request: PullResMessagesRequest, context: grpc.ServicerContext
286
+ ) -> PullResMessagesResponse:
287
+ """Pull a set of Messages."""
288
+ log(DEBUG, "ServerAppIoServicer.PullMessages")
289
+
290
+ # Init state
291
+ state: LinkState = self.state_factory.state()
292
+
293
+ # Abort if the run is not running
294
+ abort_if(
295
+ request.run_id,
296
+ [Status.PENDING, Status.STARTING, Status.FINISHED],
297
+ state,
298
+ context,
299
+ )
300
+
301
+ # Convert each task_id str to UUID
302
+ message_ids: set[UUID] = {
303
+ UUID(message_id) for message_id in request.message_ids
304
+ }
305
+
306
+ # Read from state
307
+ task_res_list: list[TaskRes] = state.get_task_res(task_ids=message_ids)
308
+
309
+ # Convert to Messages
310
+ messages_list = []
311
+ while task_res_list:
312
+ task_res = task_res_list.pop(0)
313
+ _raise_if(
314
+ validation_error=request.run_id != task_res.run_id,
315
+ request_name="PullMessages",
316
+ detail="`task_res` has mismatched `run_id`",
317
+ )
318
+ message = message_from_taskres(taskres=task_res)
319
+ messages_list.append(message_to_proto(message))
320
+
321
+ # Delete the TaskIns/TaskRes pairs if TaskRes is found
322
+ task_ins_ids_to_delete = {
323
+ UUID(task_res.task.ancestry[0]) for task_res in task_res_list
324
+ }
325
+
326
+ state.delete_tasks(task_ins_ids=task_ins_ids_to_delete)
327
+
328
+ return PullResMessagesResponse(messages_list=messages_list)
329
+
226
330
  def GetRun(
227
331
  self, request: GetRunRequest, context: grpc.ServicerContext
228
332
  ) -> GetRunResponse:
@@ -15,7 +15,7 @@
15
15
  """Fleet API gRPC adapter servicer."""
16
16
 
17
17
 
18
- from logging import DEBUG, INFO
18
+ from logging import DEBUG
19
19
  from typing import Callable, TypeVar
20
20
 
21
21
  import grpc
@@ -31,35 +31,30 @@ from flwr.common.constant import (
31
31
  from flwr.common.logger import log
32
32
  from flwr.common.version import package_name, package_version
33
33
  from flwr.proto import grpcadapter_pb2_grpc # pylint: disable=E0611
34
- from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
34
+ from flwr.proto.fab_pb2 import GetFabRequest # pylint: disable=E0611
35
35
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
36
36
  CreateNodeRequest,
37
- CreateNodeResponse,
38
37
  DeleteNodeRequest,
39
- DeleteNodeResponse,
40
38
  PingRequest,
41
- PingResponse,
42
- PullTaskInsRequest,
43
- PullTaskInsResponse,
44
- PushTaskResRequest,
45
- PushTaskResResponse,
39
+ PullMessagesRequest,
40
+ PushMessagesRequest,
46
41
  )
47
42
  from flwr.proto.grpcadapter_pb2 import MessageContainer # pylint: disable=E0611
48
- from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
49
- from flwr.server.superlink.ffs.ffs_factory import FfsFactory
50
- from flwr.server.superlink.fleet.message_handler import message_handler
51
- from flwr.server.superlink.linkstate import LinkStateFactory
43
+ from flwr.proto.run_pb2 import GetRunRequest # pylint: disable=E0611
44
+
45
+ from ..grpc_rere.fleet_servicer import FleetServicer
52
46
 
53
47
  T = TypeVar("T", bound=GrpcMessage)
54
48
 
55
49
 
56
50
  def _handle(
57
51
  msg_container: MessageContainer,
52
+ context: grpc.ServicerContext,
58
53
  request_type: type[T],
59
- handler: Callable[[T], GrpcMessage],
54
+ handler: Callable[[T, grpc.ServicerContext], GrpcMessage],
60
55
  ) -> MessageContainer:
61
56
  req = request_type.FromString(msg_container.grpc_message_content)
62
- res = handler(req)
57
+ res = handler(req, context)
63
58
  res_cls = res.__class__
64
59
  return MessageContainer(
65
60
  metadata={
@@ -74,89 +69,26 @@ def _handle(
74
69
  )
75
70
 
76
71
 
77
- class GrpcAdapterServicer(grpcadapter_pb2_grpc.GrpcAdapterServicer):
72
+ class GrpcAdapterServicer(grpcadapter_pb2_grpc.GrpcAdapterServicer, FleetServicer):
78
73
  """Fleet API via GrpcAdapter servicer."""
79
74
 
80
- def __init__(
81
- self, state_factory: LinkStateFactory, ffs_factory: FfsFactory
82
- ) -> None:
83
- self.state_factory = state_factory
84
- self.ffs_factory = ffs_factory
85
-
86
75
  def SendReceive( # pylint: disable=too-many-return-statements
87
76
  self, request: MessageContainer, context: grpc.ServicerContext
88
77
  ) -> MessageContainer:
89
78
  """."""
90
79
  log(DEBUG, "GrpcAdapterServicer.SendReceive")
91
80
  if request.grpc_message_name == CreateNodeRequest.__qualname__:
92
- return _handle(request, CreateNodeRequest, self._create_node)
81
+ return _handle(request, context, CreateNodeRequest, self.CreateNode)
93
82
  if request.grpc_message_name == DeleteNodeRequest.__qualname__:
94
- return _handle(request, DeleteNodeRequest, self._delete_node)
83
+ return _handle(request, context, DeleteNodeRequest, self.DeleteNode)
95
84
  if request.grpc_message_name == PingRequest.__qualname__:
96
- return _handle(request, PingRequest, self._ping)
97
- if request.grpc_message_name == PullTaskInsRequest.__qualname__:
98
- return _handle(request, PullTaskInsRequest, self._pull_task_ins)
99
- if request.grpc_message_name == PushTaskResRequest.__qualname__:
100
- return _handle(request, PushTaskResRequest, self._push_task_res)
85
+ return _handle(request, context, PingRequest, self.Ping)
101
86
  if request.grpc_message_name == GetRunRequest.__qualname__:
102
- return _handle(request, GetRunRequest, self._get_run)
87
+ return _handle(request, context, GetRunRequest, self.GetRun)
103
88
  if request.grpc_message_name == GetFabRequest.__qualname__:
104
- return _handle(request, GetFabRequest, self._get_fab)
89
+ return _handle(request, context, GetFabRequest, self.GetFab)
90
+ if request.grpc_message_name == PullMessagesRequest.__qualname__:
91
+ return _handle(request, context, PullMessagesRequest, self.PullMessages)
92
+ if request.grpc_message_name == PushMessagesRequest.__qualname__:
93
+ return _handle(request, context, PushMessagesRequest, self.PushMessages)
105
94
  raise ValueError(f"Invalid grpc_message_name: {request.grpc_message_name}")
106
-
107
- def _create_node(self, request: CreateNodeRequest) -> CreateNodeResponse:
108
- """."""
109
- log(INFO, "GrpcAdapter.CreateNode")
110
- return message_handler.create_node(
111
- request=request,
112
- state=self.state_factory.state(),
113
- )
114
-
115
- def _delete_node(self, request: DeleteNodeRequest) -> DeleteNodeResponse:
116
- """."""
117
- log(INFO, "GrpcAdapter.DeleteNode")
118
- return message_handler.delete_node(
119
- request=request,
120
- state=self.state_factory.state(),
121
- )
122
-
123
- def _ping(self, request: PingRequest) -> PingResponse:
124
- """."""
125
- log(DEBUG, "GrpcAdapter.Ping")
126
- return message_handler.ping(
127
- request=request,
128
- state=self.state_factory.state(),
129
- )
130
-
131
- def _pull_task_ins(self, request: PullTaskInsRequest) -> PullTaskInsResponse:
132
- """Pull TaskIns."""
133
- log(INFO, "GrpcAdapter.PullTaskIns")
134
- return message_handler.pull_task_ins(
135
- request=request,
136
- state=self.state_factory.state(),
137
- )
138
-
139
- def _push_task_res(self, request: PushTaskResRequest) -> PushTaskResResponse:
140
- """Push TaskRes."""
141
- log(INFO, "GrpcAdapter.PushTaskRes")
142
- return message_handler.push_task_res(
143
- request=request,
144
- state=self.state_factory.state(),
145
- )
146
-
147
- def _get_run(self, request: GetRunRequest) -> GetRunResponse:
148
- """Get run information."""
149
- log(INFO, "GrpcAdapter.GetRun")
150
- return message_handler.get_run(
151
- request=request,
152
- state=self.state_factory.state(),
153
- )
154
-
155
- def _get_fab(self, request: GetFabRequest) -> GetFabResponse:
156
- """Get FAB."""
157
- log(INFO, "GrpcAdapter.GetFab")
158
- return message_handler.get_fab(
159
- request=request,
160
- ffs=self.ffs_factory.ffs(),
161
- state=self.state_factory.state(),
162
- )
@@ -15,91 +15,54 @@
15
15
  """Flower server interceptor."""
16
16
 
17
17
 
18
- import base64
19
- from collections.abc import Sequence
20
- from logging import INFO, WARNING
21
- from typing import Any, Callable, Optional, Union
18
+ import datetime
19
+ from typing import Any, Callable, Optional, cast
22
20
 
23
21
  import grpc
24
- from cryptography.hazmat.primitives.asymmetric import ec
25
-
26
- from flwr.common.logger import log
22
+ from google.protobuf.message import Message as GrpcMessage
23
+
24
+ from flwr.common import now
25
+ from flwr.common.constant import (
26
+ PUBLIC_KEY_HEADER,
27
+ SIGNATURE_HEADER,
28
+ TIMESTAMP_HEADER,
29
+ TIMESTAMP_TOLERANCE,
30
+ )
27
31
  from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
28
- bytes_to_private_key,
29
32
  bytes_to_public_key,
30
- generate_shared_key,
31
- verify_hmac,
33
+ verify_signature,
32
34
  )
33
- from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
34
35
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
35
36
  CreateNodeRequest,
36
37
  CreateNodeResponse,
37
- DeleteNodeRequest,
38
- DeleteNodeResponse,
39
- PingRequest,
40
- PingResponse,
41
- PullTaskInsRequest,
42
- PullTaskInsResponse,
43
- PushTaskResRequest,
44
- PushTaskResResponse,
45
38
  )
46
- from flwr.proto.node_pb2 import Node # pylint: disable=E0611
47
- from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
48
39
  from flwr.server.superlink.linkstate import LinkStateFactory
49
40
 
50
- _PUBLIC_KEY_HEADER = "public-key"
51
- _AUTH_TOKEN_HEADER = "auth-token"
52
-
53
- Request = Union[
54
- CreateNodeRequest,
55
- DeleteNodeRequest,
56
- PullTaskInsRequest,
57
- PushTaskResRequest,
58
- GetRunRequest,
59
- PingRequest,
60
- GetFabRequest,
61
- ]
62
-
63
- Response = Union[
64
- CreateNodeResponse,
65
- DeleteNodeResponse,
66
- PullTaskInsResponse,
67
- PushTaskResResponse,
68
- GetRunResponse,
69
- PingResponse,
70
- GetFabResponse,
71
- ]
72
-
73
41
 
74
- def _get_value_from_tuples(
75
- key_string: str, tuples: Sequence[tuple[str, Union[str, bytes]]]
76
- ) -> bytes:
77
- value = next((value for key, value in tuples if key == key_string), "")
78
- if isinstance(value, str):
79
- return value.encode()
42
+ def _unary_unary_rpc_terminator(message: str) -> grpc.RpcMethodHandler:
43
+ def terminate(_request: GrpcMessage, context: grpc.ServicerContext) -> GrpcMessage:
44
+ context.abort(grpc.StatusCode.UNAUTHENTICATED, message)
45
+ raise RuntimeError("Should not reach this point") # Make mypy happy
80
46
 
81
- return value
47
+ return grpc.unary_unary_rpc_method_handler(terminate)
82
48
 
83
49
 
84
50
  class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
85
- """Server interceptor for node authentication."""
86
-
87
- def __init__(self, state_factory: LinkStateFactory):
51
+ """Server interceptor for node authentication.
52
+
53
+ Parameters
54
+ ----------
55
+ state_factory : LinkStateFactory
56
+ A factory for creating new instances of LinkState.
57
+ auto_auth : bool (default: False)
58
+ If True, nodes are authenticated without requiring their public keys to be
59
+ pre-stored in the LinkState. If False, only nodes with pre-stored public keys
60
+ can be authenticated.
61
+ """
62
+
63
+ def __init__(self, state_factory: LinkStateFactory, auto_auth: bool = False):
88
64
  self.state_factory = state_factory
89
- state = self.state_factory.state()
90
-
91
- self.node_public_keys = state.get_node_public_keys()
92
- if len(self.node_public_keys) == 0:
93
- log(WARNING, "Authentication enabled, but no known public keys configured")
94
-
95
- private_key = state.get_server_private_key()
96
- public_key = state.get_server_public_key()
97
-
98
- if private_key is None or public_key is None:
99
- raise ValueError("Error loading authentication keys")
100
-
101
- self.server_private_key = bytes_to_private_key(private_key)
102
- self.encoded_server_public_key = base64.urlsafe_b64encode(public_key)
65
+ self.auto_auth = auto_auth
103
66
 
104
67
  def intercept_service(
105
68
  self,
@@ -112,117 +75,80 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
112
75
  metadata sent by the node. Continue RPC call if node is authenticated, else,
113
76
  terminate RPC call by setting context to abort.
114
77
  """
78
+ state = self.state_factory.state()
79
+ metadata_dict = dict(handler_call_details.invocation_metadata)
80
+
81
+ # Retrieve info from the metadata
82
+ try:
83
+ node_pk_bytes = cast(bytes, metadata_dict[PUBLIC_KEY_HEADER])
84
+ timestamp_iso = cast(str, metadata_dict[TIMESTAMP_HEADER])
85
+ signature = cast(bytes, metadata_dict[SIGNATURE_HEADER])
86
+ except KeyError:
87
+ return _unary_unary_rpc_terminator("Missing authentication metadata")
88
+
89
+ if not self.auto_auth:
90
+ # Abort the RPC call if the node public key is not found
91
+ if node_pk_bytes not in state.get_node_public_keys():
92
+ return _unary_unary_rpc_terminator("Public key not recognized")
93
+
94
+ # Verify the signature
95
+ node_pk = bytes_to_public_key(node_pk_bytes)
96
+ if not verify_signature(node_pk, timestamp_iso.encode("ascii"), signature):
97
+ return _unary_unary_rpc_terminator("Invalid signature")
98
+
99
+ # Verify the timestamp
100
+ current = now()
101
+ time_diff = current - datetime.datetime.fromisoformat(timestamp_iso)
102
+ # Abort the RPC call if the timestamp is too old or in the future
103
+ if not 0 < time_diff.total_seconds() < TIMESTAMP_TOLERANCE:
104
+ return _unary_unary_rpc_terminator("Invalid timestamp")
105
+
106
+ # Continue the RPC call
107
+ expected_node_id = state.get_node_id(node_pk_bytes)
108
+ if not handler_call_details.method.endswith("CreateNode"):
109
+ if expected_node_id is None:
110
+ return _unary_unary_rpc_terminator("Invalid node ID")
115
111
  # One of the method handlers in
116
112
  # `flwr.server.superlink.fleet.grpc_rere.fleet_server.FleetServicer`
117
113
  method_handler: grpc.RpcMethodHandler = continuation(handler_call_details)
118
- return self._generic_auth_unary_method_handler(method_handler)
114
+ return self._wrap_method_handler(
115
+ method_handler, expected_node_id, node_pk_bytes
116
+ )
119
117
 
120
- def _generic_auth_unary_method_handler(
121
- self, method_handler: grpc.RpcMethodHandler
118
+ def _wrap_method_handler(
119
+ self,
120
+ method_handler: grpc.RpcMethodHandler,
121
+ expected_node_id: Optional[int],
122
+ node_public_key: bytes,
122
123
  ) -> grpc.RpcMethodHandler:
123
124
  def _generic_method_handler(
124
- request: Request,
125
+ request: GrpcMessage,
125
126
  context: grpc.ServicerContext,
126
- ) -> Response:
127
- node_public_key_bytes = base64.urlsafe_b64decode(
128
- _get_value_from_tuples(
129
- _PUBLIC_KEY_HEADER, context.invocation_metadata()
130
- )
131
- )
132
- if node_public_key_bytes not in self.node_public_keys:
133
- context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
134
-
135
- if isinstance(request, CreateNodeRequest):
136
- response = self._create_authenticated_node(
137
- node_public_key_bytes, request, context
138
- )
139
- log(
140
- INFO,
141
- "AuthenticateServerInterceptor: Created node_id=%s",
142
- response.node.node_id,
143
- )
144
- return response
145
-
146
- # Verify hmac value
147
- hmac_value = base64.urlsafe_b64decode(
148
- _get_value_from_tuples(
149
- _AUTH_TOKEN_HEADER, context.invocation_metadata()
150
- )
151
- )
152
- public_key = bytes_to_public_key(node_public_key_bytes)
153
-
154
- if not self._verify_hmac(public_key, request, hmac_value):
155
- context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
156
-
157
- # Verify node_id
158
- node_id = self.state_factory.state().get_node_id(node_public_key_bytes)
159
-
160
- if not self._verify_node_id(node_id, request):
161
- context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
162
-
163
- return method_handler.unary_unary(request, context) # type: ignore
127
+ ) -> GrpcMessage:
128
+ # Verify the node ID
129
+ if not isinstance(request, CreateNodeRequest):
130
+ try:
131
+ if request.node.node_id != expected_node_id: # type: ignore
132
+ raise ValueError
133
+ except (AttributeError, ValueError):
134
+ context.abort(grpc.StatusCode.UNAUTHENTICATED, "Invalid node ID")
135
+
136
+ response: GrpcMessage = method_handler.unary_unary(request, context)
137
+
138
+ # Set the public key after a successful CreateNode request
139
+ if isinstance(response, CreateNodeResponse):
140
+ state = self.state_factory.state()
141
+ try:
142
+ state.set_node_public_key(response.node.node_id, node_public_key)
143
+ except ValueError as e:
144
+ # Remove newly created node if setting the public key fails
145
+ state.delete_node(response.node.node_id)
146
+ context.abort(grpc.StatusCode.UNAUTHENTICATED, str(e))
147
+
148
+ return response
164
149
 
165
150
  return grpc.unary_unary_rpc_method_handler(
166
151
  _generic_method_handler,
167
152
  request_deserializer=method_handler.request_deserializer,
168
153
  response_serializer=method_handler.response_serializer,
169
154
  )
170
-
171
- def _verify_node_id(
172
- self,
173
- node_id: Optional[int],
174
- request: Union[
175
- DeleteNodeRequest,
176
- PullTaskInsRequest,
177
- PushTaskResRequest,
178
- GetRunRequest,
179
- PingRequest,
180
- GetFabRequest,
181
- ],
182
- ) -> bool:
183
- if node_id is None:
184
- return False
185
- if isinstance(request, PushTaskResRequest):
186
- if len(request.task_res_list) == 0:
187
- return False
188
- return request.task_res_list[0].task.producer.node_id == node_id
189
- if isinstance(request, GetRunRequest):
190
- return node_id in self.state_factory.state().get_nodes(request.run_id)
191
- return request.node.node_id == node_id
192
-
193
- def _verify_hmac(
194
- self, public_key: ec.EllipticCurvePublicKey, request: Request, hmac_value: bytes
195
- ) -> bool:
196
- shared_secret = generate_shared_key(self.server_private_key, public_key)
197
- message_bytes = request.SerializeToString(deterministic=True)
198
- return verify_hmac(shared_secret, message_bytes, hmac_value)
199
-
200
- def _create_authenticated_node(
201
- self,
202
- public_key_bytes: bytes,
203
- request: CreateNodeRequest,
204
- context: grpc.ServicerContext,
205
- ) -> CreateNodeResponse:
206
- context.send_initial_metadata(
207
- (
208
- (
209
- _PUBLIC_KEY_HEADER,
210
- self.encoded_server_public_key,
211
- ),
212
- )
213
- )
214
- state = self.state_factory.state()
215
- node_id = state.get_node_id(public_key_bytes)
216
-
217
- # Handle `CreateNode` here instead of calling the default method handler
218
- # Return previously assigned `node_id` for the provided `public_key`
219
- if node_id is not None:
220
- state.acknowledge_ping(node_id, request.ping_interval)
221
- return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
222
-
223
- # No `node_id` exists for the provided `public_key`
224
- # Handle `CreateNode` here instead of calling the default method handler
225
- # Note: the innermost `CreateNode` method will never be called
226
- node_id = state.create_node(request.ping_interval)
227
- state.set_node_public_key(node_id, public_key_bytes)
228
- return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
@@ -66,13 +66,13 @@ def create_node(
66
66
  """."""
67
67
  # Create node
68
68
  node_id = state.create_node(ping_interval=request.ping_interval)
69
- return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
69
+ return CreateNodeResponse(node=Node(node_id=node_id))
70
70
 
71
71
 
72
72
  def delete_node(request: DeleteNodeRequest, state: LinkState) -> DeleteNodeResponse:
73
73
  """."""
74
74
  # Validate node_id
75
- if request.node.anonymous or request.node.node_id == 0:
75
+ if request.node.node_id == 0: # i.e. unset `node_id`
76
76
  return DeleteNodeResponse()
77
77
 
78
78
  # Update state
@@ -91,9 +91,8 @@ def ping(
91
91
 
92
92
  def pull_task_ins(request: PullTaskInsRequest, state: LinkState) -> PullTaskInsResponse:
93
93
  """Pull TaskIns handler."""
94
- # Get node_id if client node is not anonymous
95
94
  node = request.node # pylint: disable=no-member
96
- node_id: Optional[int] = None if node.anonymous else node.node_id
95
+ node_id: int = node.node_id
97
96
 
98
97
  # Retrieve TaskIns from State
99
98
  task_ins_list: list[TaskIns] = state.get_task_ins(node_id=node_id, limit=1)
@@ -111,7 +110,7 @@ def pull_messages(
111
110
  """Pull Messages handler."""
112
111
  # Get node_id if client node is not anonymous
113
112
  node = request.node # pylint: disable=no-member
114
- node_id: Optional[int] = None if node.anonymous else node.node_id
113
+ node_id: int = node.node_id
115
114
 
116
115
  # Retrieve TaskIns from State
117
116
  task_ins_list: list[TaskIns] = state.get_task_ins(node_id=node_id, limit=1)