flwr 1.14.0__py3-none-any.whl → 1.15.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (103)
  1. flwr/cli/auth_plugin/__init__.py +31 -0
  2. flwr/cli/auth_plugin/oidc_cli_plugin.py +150 -0
  3. flwr/cli/cli_user_auth_interceptor.py +6 -2
  4. flwr/cli/config_utils.py +24 -147
  5. flwr/cli/constant.py +27 -0
  6. flwr/cli/install.py +1 -1
  7. flwr/cli/log.py +18 -3
  8. flwr/cli/login/login.py +43 -8
  9. flwr/cli/ls.py +14 -5
  10. flwr/cli/new/templates/app/README.md.tpl +3 -2
  11. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +4 -4
  12. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
  13. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +4 -4
  14. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +2 -2
  15. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -2
  16. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +2 -2
  17. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +4 -4
  18. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +3 -3
  19. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
  20. flwr/cli/run/run.py +21 -11
  21. flwr/cli/stop.py +13 -4
  22. flwr/cli/utils.py +54 -40
  23. flwr/client/app.py +36 -48
  24. flwr/client/clientapp/app.py +19 -25
  25. flwr/client/clientapp/utils.py +1 -1
  26. flwr/client/grpc_client/connection.py +1 -12
  27. flwr/client/grpc_rere_client/client_interceptor.py +19 -119
  28. flwr/client/grpc_rere_client/connection.py +46 -36
  29. flwr/client/grpc_rere_client/grpc_adapter.py +12 -12
  30. flwr/client/message_handler/task_handler.py +0 -17
  31. flwr/client/rest_client/connection.py +34 -26
  32. flwr/client/supernode/app.py +18 -72
  33. flwr/common/args.py +25 -47
  34. flwr/common/auth_plugin/auth_plugin.py +34 -23
  35. flwr/common/config.py +166 -16
  36. flwr/common/constant.py +24 -9
  37. flwr/common/differential_privacy.py +2 -1
  38. flwr/common/exit/__init__.py +24 -0
  39. flwr/common/exit/exit.py +99 -0
  40. flwr/common/exit/exit_code.py +93 -0
  41. flwr/common/exit_handlers.py +32 -30
  42. flwr/common/grpc.py +167 -4
  43. flwr/common/logger.py +26 -7
  44. flwr/common/object_ref.py +0 -14
  45. flwr/common/record/recordset.py +1 -1
  46. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +45 -0
  47. flwr/common/serde.py +6 -4
  48. flwr/common/typing.py +20 -0
  49. flwr/proto/clientappio_pb2.py +1 -1
  50. flwr/proto/error_pb2.py +1 -1
  51. flwr/proto/exec_pb2.py +13 -25
  52. flwr/proto/exec_pb2.pyi +27 -54
  53. flwr/proto/fab_pb2.py +1 -1
  54. flwr/proto/fleet_pb2.py +31 -31
  55. flwr/proto/fleet_pb2.pyi +23 -23
  56. flwr/proto/fleet_pb2_grpc.py +30 -30
  57. flwr/proto/fleet_pb2_grpc.pyi +20 -20
  58. flwr/proto/grpcadapter_pb2.py +1 -1
  59. flwr/proto/log_pb2.py +1 -1
  60. flwr/proto/message_pb2.py +1 -1
  61. flwr/proto/node_pb2.py +3 -3
  62. flwr/proto/node_pb2.pyi +1 -4
  63. flwr/proto/recordset_pb2.py +1 -1
  64. flwr/proto/run_pb2.py +1 -1
  65. flwr/proto/serverappio_pb2.py +24 -25
  66. flwr/proto/serverappio_pb2.pyi +26 -32
  67. flwr/proto/serverappio_pb2_grpc.py +28 -28
  68. flwr/proto/serverappio_pb2_grpc.pyi +16 -16
  69. flwr/proto/simulationio_pb2.py +1 -1
  70. flwr/proto/task_pb2.py +1 -1
  71. flwr/proto/transport_pb2.py +1 -1
  72. flwr/server/app.py +116 -128
  73. flwr/server/compat/app_utils.py +0 -1
  74. flwr/server/compat/driver_client_proxy.py +1 -2
  75. flwr/server/driver/grpc_driver.py +32 -27
  76. flwr/server/driver/inmemory_driver.py +2 -1
  77. flwr/server/serverapp/app.py +12 -10
  78. flwr/server/superlink/driver/serverappio_grpc.py +1 -1
  79. flwr/server/superlink/driver/serverappio_servicer.py +74 -48
  80. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +20 -88
  81. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +2 -165
  82. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +25 -24
  83. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +110 -168
  84. flwr/server/superlink/fleet/message_handler/message_handler.py +37 -24
  85. flwr/server/superlink/fleet/rest_rere/rest_api.py +16 -18
  86. flwr/server/superlink/fleet/vce/vce_api.py +2 -2
  87. flwr/server/superlink/linkstate/in_memory_linkstate.py +45 -75
  88. flwr/server/superlink/linkstate/linkstate.py +17 -38
  89. flwr/server/superlink/linkstate/sqlite_linkstate.py +81 -145
  90. flwr/server/superlink/linkstate/utils.py +18 -8
  91. flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
  92. flwr/server/utils/validator.py +9 -34
  93. flwr/simulation/app.py +4 -6
  94. flwr/simulation/legacy_app.py +4 -2
  95. flwr/simulation/run_simulation.py +1 -1
  96. flwr/simulation/simulationio_connection.py +2 -1
  97. flwr/superexec/exec_grpc.py +1 -1
  98. flwr/superexec/exec_servicer.py +23 -2
  99. {flwr-1.14.0.dist-info → flwr-1.15.1.dist-info}/METADATA +8 -8
  100. {flwr-1.14.0.dist-info → flwr-1.15.1.dist-info}/RECORD +103 -97
  101. {flwr-1.14.0.dist-info → flwr-1.15.1.dist-info}/LICENSE +0 -0
  102. {flwr-1.14.0.dist-info → flwr-1.15.1.dist-info}/WHEEL +0 -0
  103. {flwr-1.14.0.dist-info → flwr-1.15.1.dist-info}/entry_points.txt +0 -0
@@ -18,6 +18,7 @@
18
18
  from logging import DEBUG, INFO
19
19
 
20
20
  import grpc
21
+ from google.protobuf.json_format import MessageToDict
21
22
 
22
23
  from flwr.common.logger import log
23
24
  from flwr.common.typing import InvalidRunStatusException
@@ -30,10 +31,10 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
30
31
  DeleteNodeResponse,
31
32
  PingRequest,
32
33
  PingResponse,
33
- PullTaskInsRequest,
34
- PullTaskInsResponse,
35
- PushTaskResRequest,
36
- PushTaskResResponse,
34
+ PullMessagesRequest,
35
+ PullMessagesResponse,
36
+ PushMessagesRequest,
37
+ PushMessagesResponse,
37
38
  )
38
39
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
39
40
  from flwr.server.superlink.ffs.ffs_factory import FfsFactory
@@ -56,13 +57,13 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
56
57
  ) -> CreateNodeResponse:
57
58
  """."""
58
59
  log(INFO, "[Fleet.CreateNode] Request ping_interval=%s", request.ping_interval)
59
- log(DEBUG, "[Fleet.CreateNode] Request: %s", request)
60
+ log(DEBUG, "[Fleet.CreateNode] Request: %s", MessageToDict(request))
60
61
  response = message_handler.create_node(
61
62
  request=request,
62
63
  state=self.state_factory.state(),
63
64
  )
64
65
  log(INFO, "[Fleet.CreateNode] Created node_id=%s", response.node.node_id)
65
- log(DEBUG, "[Fleet.CreateNode] Response: %s", response)
66
+ log(DEBUG, "[Fleet.CreateNode] Response: %s", MessageToDict(response))
66
67
  return response
67
68
 
68
69
  def DeleteNode(
@@ -70,7 +71,7 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
70
71
  ) -> DeleteNodeResponse:
71
72
  """."""
72
73
  log(INFO, "[Fleet.DeleteNode] Delete node_id=%s", request.node.node_id)
73
- log(DEBUG, "[Fleet.DeleteNode] Request: %s", request)
74
+ log(DEBUG, "[Fleet.DeleteNode] Request: %s", MessageToDict(request))
74
75
  return message_handler.delete_node(
75
76
  request=request,
76
77
  state=self.state_factory.state(),
@@ -78,38 +79,38 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
78
79
 
79
80
  def Ping(self, request: PingRequest, context: grpc.ServicerContext) -> PingResponse:
80
81
  """."""
81
- log(DEBUG, "[Fleet.Ping] Request: %s", request)
82
+ log(DEBUG, "[Fleet.Ping] Request: %s", MessageToDict(request))
82
83
  return message_handler.ping(
83
84
  request=request,
84
85
  state=self.state_factory.state(),
85
86
  )
86
87
 
87
- def PullTaskIns(
88
- self, request: PullTaskInsRequest, context: grpc.ServicerContext
89
- ) -> PullTaskInsResponse:
90
- """Pull TaskIns."""
91
- log(INFO, "[Fleet.PullTaskIns] node_id=%s", request.node.node_id)
92
- log(DEBUG, "[Fleet.PullTaskIns] Request: %s", request)
93
- return message_handler.pull_task_ins(
88
+ def PullMessages(
89
+ self, request: PullMessagesRequest, context: grpc.ServicerContext
90
+ ) -> PullMessagesResponse:
91
+ """Pull Messages."""
92
+ log(INFO, "[Fleet.PullMessages] node_id=%s", request.node.node_id)
93
+ log(DEBUG, "[Fleet.PullMessages] Request: %s", MessageToDict(request))
94
+ return message_handler.pull_messages(
94
95
  request=request,
95
96
  state=self.state_factory.state(),
96
97
  )
97
98
 
98
- def PushTaskRes(
99
- self, request: PushTaskResRequest, context: grpc.ServicerContext
100
- ) -> PushTaskResResponse:
101
- """Push TaskRes."""
102
- if request.task_res_list:
99
+ def PushMessages(
100
+ self, request: PushMessagesRequest, context: grpc.ServicerContext
101
+ ) -> PushMessagesResponse:
102
+ """Push Messages."""
103
+ if request.messages_list:
103
104
  log(
104
105
  INFO,
105
- "[Fleet.PushTaskRes] Push results from node_id=%s",
106
- request.task_res_list[0].task.producer.node_id,
106
+ "[Fleet.PushMessages] Push results from node_id=%s",
107
+ request.messages_list[0].metadata.src_node_id,
107
108
  )
108
109
  else:
109
- log(INFO, "[Fleet.PushTaskRes] No task results to push")
110
+ log(INFO, "[Fleet.PushMessages] No task results to push")
110
111
 
111
112
  try:
112
- res = message_handler.push_task_res(
113
+ res = message_handler.push_messages(
113
114
  request=request,
114
115
  state=self.state_factory.state(),
115
116
  )
@@ -15,93 +15,62 @@
15
15
  """Flower server interceptor."""
16
16
 
17
17
 
18
- import base64
19
- from collections.abc import Sequence
20
- from logging import INFO, WARNING
21
- from typing import Any, Callable, Optional, Union
18
+ import datetime
19
+ from typing import Any, Callable, Optional, cast
22
20
 
23
21
  import grpc
24
- from cryptography.hazmat.primitives.asymmetric import ec
25
-
26
- from flwr.common.logger import log
22
+ from google.protobuf.message import Message as GrpcMessage
23
+
24
+ from flwr.common import now
25
+ from flwr.common.constant import (
26
+ PUBLIC_KEY_HEADER,
27
+ SIGNATURE_HEADER,
28
+ SYSTEM_TIME_TOLERANCE,
29
+ TIMESTAMP_HEADER,
30
+ TIMESTAMP_TOLERANCE,
31
+ )
27
32
  from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
28
- bytes_to_private_key,
29
33
  bytes_to_public_key,
30
- generate_shared_key,
31
- verify_hmac,
34
+ verify_signature,
32
35
  )
33
- from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
34
36
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
35
37
  CreateNodeRequest,
36
38
  CreateNodeResponse,
37
- DeleteNodeRequest,
38
- DeleteNodeResponse,
39
- PingRequest,
40
- PingResponse,
41
- PullTaskInsRequest,
42
- PullTaskInsResponse,
43
- PushTaskResRequest,
44
- PushTaskResResponse,
45
39
  )
46
- from flwr.proto.node_pb2 import Node # pylint: disable=E0611
47
- from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
48
40
  from flwr.server.superlink.linkstate import LinkStateFactory
49
41
 
50
- _PUBLIC_KEY_HEADER = "public-key"
51
- _AUTH_TOKEN_HEADER = "auth-token"
42
+ MIN_TIMESTAMP_DIFF = -SYSTEM_TIME_TOLERANCE
43
+ MAX_TIMESTAMP_DIFF = TIMESTAMP_TOLERANCE + SYSTEM_TIME_TOLERANCE
52
44
 
53
- Request = Union[
54
- CreateNodeRequest,
55
- DeleteNodeRequest,
56
- PullTaskInsRequest,
57
- PushTaskResRequest,
58
- GetRunRequest,
59
- PingRequest,
60
- GetFabRequest,
61
- ]
62
-
63
- Response = Union[
64
- CreateNodeResponse,
65
- DeleteNodeResponse,
66
- PullTaskInsResponse,
67
- PushTaskResResponse,
68
- GetRunResponse,
69
- PingResponse,
70
- GetFabResponse,
71
- ]
72
45
 
46
+ def _unary_unary_rpc_terminator(
47
+ message: str, code: Any = grpc.StatusCode.UNAUTHENTICATED
48
+ ) -> grpc.RpcMethodHandler:
49
+ def terminate(_request: GrpcMessage, context: grpc.ServicerContext) -> GrpcMessage:
50
+ context.abort(code, message)
51
+ raise RuntimeError("Should not reach this point") # Make mypy happy
73
52
 
74
- def _get_value_from_tuples(
75
- key_string: str, tuples: Sequence[tuple[str, Union[str, bytes]]]
76
- ) -> bytes:
77
- value = next((value for key, value in tuples if key == key_string), "")
78
- if isinstance(value, str):
79
- return value.encode()
80
-
81
- return value
53
+ return grpc.unary_unary_rpc_method_handler(terminate)
82
54
 
83
55
 
84
56
  class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
85
- """Server interceptor for node authentication."""
86
-
87
- def __init__(self, state_factory: LinkStateFactory):
57
+ """Server interceptor for node authentication.
58
+
59
+ Parameters
60
+ ----------
61
+ state_factory : LinkStateFactory
62
+ A factory for creating new instances of LinkState.
63
+ auto_auth : bool (default: False)
64
+ If True, nodes are authenticated without requiring their public keys to be
65
+ pre-stored in the LinkState. If False, only nodes with pre-stored public keys
66
+ can be authenticated.
67
+ """
68
+
69
+ def __init__(self, state_factory: LinkStateFactory, auto_auth: bool = False):
88
70
  self.state_factory = state_factory
89
- state = self.state_factory.state()
71
+ self.auto_auth = auto_auth
90
72
 
91
- self.node_public_keys = state.get_node_public_keys()
92
- if len(self.node_public_keys) == 0:
93
- log(WARNING, "Authentication enabled, but no known public keys configured")
94
-
95
- private_key = state.get_server_private_key()
96
- public_key = state.get_server_public_key()
97
-
98
- if private_key is None or public_key is None:
99
- raise ValueError("Error loading authentication keys")
100
-
101
- self.server_private_key = bytes_to_private_key(private_key)
102
- self.encoded_server_public_key = base64.urlsafe_b64encode(public_key)
103
-
104
- def intercept_service(
73
+ def intercept_service( # pylint: disable=too-many-return-statements
105
74
  self,
106
75
  continuation: Callable[[Any], Any],
107
76
  handler_call_details: grpc.HandlerCallDetails,
@@ -112,116 +81,89 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
112
81
  metadata sent by the node. Continue RPC call if node is authenticated, else,
113
82
  terminate RPC call by setting context to abort.
114
83
  """
84
+ # Filter out non-Fleet service calls
85
+ if not handler_call_details.method.startswith("/flwr.proto.Fleet/"):
86
+ return _unary_unary_rpc_terminator(
87
+ "This request should be sent to a different service.",
88
+ grpc.StatusCode.FAILED_PRECONDITION,
89
+ )
90
+
91
+ state = self.state_factory.state()
92
+ metadata_dict = dict(handler_call_details.invocation_metadata)
93
+
94
+ # Retrieve info from the metadata
95
+ try:
96
+ node_pk_bytes = cast(bytes, metadata_dict[PUBLIC_KEY_HEADER])
97
+ timestamp_iso = cast(str, metadata_dict[TIMESTAMP_HEADER])
98
+ signature = cast(bytes, metadata_dict[SIGNATURE_HEADER])
99
+ except KeyError:
100
+ return _unary_unary_rpc_terminator("Missing authentication metadata")
101
+
102
+ if not self.auto_auth:
103
+ # Abort the RPC call if the node public key is not found
104
+ if node_pk_bytes not in state.get_node_public_keys():
105
+ return _unary_unary_rpc_terminator("Public key not recognized")
106
+
107
+ # Verify the signature
108
+ node_pk = bytes_to_public_key(node_pk_bytes)
109
+ if not verify_signature(node_pk, timestamp_iso.encode("ascii"), signature):
110
+ return _unary_unary_rpc_terminator("Invalid signature")
111
+
112
+ # Verify the timestamp
113
+ current = now()
114
+ time_diff = current - datetime.datetime.fromisoformat(timestamp_iso)
115
+ # Abort the RPC call if the timestamp is too old or in the future
116
+ if not MIN_TIMESTAMP_DIFF < time_diff.total_seconds() < MAX_TIMESTAMP_DIFF:
117
+ return _unary_unary_rpc_terminator("Invalid timestamp")
118
+
119
+ # Continue the RPC call
120
+ expected_node_id = state.get_node_id(node_pk_bytes)
121
+ if not handler_call_details.method.endswith("CreateNode"):
122
+ # All calls, except for `CreateNode`, must provide a public key that is
123
+ # already mapped to a `node_id` (in `LinkState`)
124
+ if expected_node_id is None:
125
+ return _unary_unary_rpc_terminator("Invalid node ID")
115
126
  # One of the method handlers in
116
127
  # `flwr.server.superlink.fleet.grpc_rere.fleet_server.FleetServicer`
117
128
  method_handler: grpc.RpcMethodHandler = continuation(handler_call_details)
118
- return self._generic_auth_unary_method_handler(method_handler)
129
+ return self._wrap_method_handler(
130
+ method_handler, expected_node_id, node_pk_bytes
131
+ )
119
132
 
120
- def _generic_auth_unary_method_handler(
121
- self, method_handler: grpc.RpcMethodHandler
133
+ def _wrap_method_handler(
134
+ self,
135
+ method_handler: grpc.RpcMethodHandler,
136
+ expected_node_id: Optional[int],
137
+ node_public_key: bytes,
122
138
  ) -> grpc.RpcMethodHandler:
123
139
  def _generic_method_handler(
124
- request: Request,
140
+ request: GrpcMessage,
125
141
  context: grpc.ServicerContext,
126
- ) -> Response:
127
- node_public_key_bytes = base64.urlsafe_b64decode(
128
- _get_value_from_tuples(
129
- _PUBLIC_KEY_HEADER, context.invocation_metadata()
130
- )
131
- )
132
- if node_public_key_bytes not in self.node_public_keys:
133
- context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
134
-
135
- if isinstance(request, CreateNodeRequest):
136
- response = self._create_authenticated_node(
137
- node_public_key_bytes, request, context
138
- )
139
- log(
140
- INFO,
141
- "AuthenticateServerInterceptor: Created node_id=%s",
142
- response.node.node_id,
143
- )
144
- return response
145
-
146
- # Verify hmac value
147
- hmac_value = base64.urlsafe_b64decode(
148
- _get_value_from_tuples(
149
- _AUTH_TOKEN_HEADER, context.invocation_metadata()
150
- )
151
- )
152
- public_key = bytes_to_public_key(node_public_key_bytes)
153
-
154
- if not self._verify_hmac(public_key, request, hmac_value):
155
- context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
156
-
157
- # Verify node_id
158
- node_id = self.state_factory.state().get_node_id(node_public_key_bytes)
159
-
160
- if not self._verify_node_id(node_id, request):
161
- context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
162
-
163
- return method_handler.unary_unary(request, context) # type: ignore
142
+ ) -> GrpcMessage:
143
+ # Verify the node ID
144
+ if not isinstance(request, CreateNodeRequest):
145
+ try:
146
+ if request.node.node_id != expected_node_id: # type: ignore
147
+ raise ValueError
148
+ except (AttributeError, ValueError):
149
+ context.abort(grpc.StatusCode.UNAUTHENTICATED, "Invalid node ID")
150
+
151
+ response: GrpcMessage = method_handler.unary_unary(request, context)
152
+
153
+ # Set the public key after a successful CreateNode request
154
+ if isinstance(response, CreateNodeResponse):
155
+ state = self.state_factory.state()
156
+ try:
157
+ state.set_node_public_key(response.node.node_id, node_public_key)
158
+ except ValueError as e:
159
+ # Remove newly created node if setting the public key fails
160
+ state.delete_node(response.node.node_id)
161
+ context.abort(grpc.StatusCode.UNAUTHENTICATED, str(e))
162
+
163
+ return response
164
164
 
165
165
  return grpc.unary_unary_rpc_method_handler(
166
166
  _generic_method_handler,
167
167
  request_deserializer=method_handler.request_deserializer,
168
168
  response_serializer=method_handler.response_serializer,
169
169
  )
170
-
171
- def _verify_node_id(
172
- self,
173
- node_id: Optional[int],
174
- request: Union[
175
- DeleteNodeRequest,
176
- PullTaskInsRequest,
177
- PushTaskResRequest,
178
- GetRunRequest,
179
- PingRequest,
180
- GetFabRequest,
181
- ],
182
- ) -> bool:
183
- if node_id is None:
184
- return False
185
- if isinstance(request, PushTaskResRequest):
186
- if len(request.task_res_list) == 0:
187
- return False
188
- return request.task_res_list[0].task.producer.node_id == node_id
189
- if isinstance(request, GetRunRequest):
190
- return node_id in self.state_factory.state().get_nodes(request.run_id)
191
- return request.node.node_id == node_id
192
-
193
- def _verify_hmac(
194
- self, public_key: ec.EllipticCurvePublicKey, request: Request, hmac_value: bytes
195
- ) -> bool:
196
- shared_secret = generate_shared_key(self.server_private_key, public_key)
197
- message_bytes = request.SerializeToString(deterministic=True)
198
- return verify_hmac(shared_secret, message_bytes, hmac_value)
199
-
200
- def _create_authenticated_node(
201
- self,
202
- public_key_bytes: bytes,
203
- request: CreateNodeRequest,
204
- context: grpc.ServicerContext,
205
- ) -> CreateNodeResponse:
206
- context.send_initial_metadata(
207
- (
208
- (
209
- _PUBLIC_KEY_HEADER,
210
- self.encoded_server_public_key,
211
- ),
212
- )
213
- )
214
- state = self.state_factory.state()
215
- node_id = state.get_node_id(public_key_bytes)
216
-
217
- # Handle `CreateNode` here instead of calling the default method handler
218
- # Return previously assigned `node_id` for the provided `public_key`
219
- if node_id is not None:
220
- state.acknowledge_ping(node_id, request.ping_interval)
221
- return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
222
-
223
- # No `node_id` exists for the provided `public_key`
224
- # Handle `CreateNode` here instead of calling the default method handler
225
- # Note: the innermost `CreateNode` method will never be called
226
- node_id = state.create_node(request.ping_interval, public_key_bytes)
227
- return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
@@ -20,7 +20,14 @@ from typing import Optional
20
20
  from uuid import UUID
21
21
 
22
22
  from flwr.common.constant import Status
23
- from flwr.common.serde import fab_to_proto, user_config_to_proto
23
+ from flwr.common.serde import (
24
+ fab_to_proto,
25
+ message_from_proto,
26
+ message_from_taskins,
27
+ message_to_proto,
28
+ message_to_taskres,
29
+ user_config_to_proto,
30
+ )
24
31
  from flwr.common.typing import Fab, InvalidRunStatusException
25
32
  from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
26
33
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
@@ -30,10 +37,10 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
30
37
  DeleteNodeResponse,
31
38
  PingRequest,
32
39
  PingResponse,
33
- PullTaskInsRequest,
34
- PullTaskInsResponse,
35
- PushTaskResRequest,
36
- PushTaskResResponse,
40
+ PullMessagesRequest,
41
+ PullMessagesResponse,
42
+ PushMessagesRequest,
43
+ PushMessagesResponse,
37
44
  Reconnect,
38
45
  )
39
46
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
@@ -42,7 +49,7 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
42
49
  GetRunResponse,
43
50
  Run,
44
51
  )
45
- from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
52
+ from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
46
53
  from flwr.server.superlink.ffs.ffs import Ffs
47
54
  from flwr.server.superlink.linkstate import LinkState
48
55
  from flwr.server.superlink.utils import check_abort
@@ -55,13 +62,13 @@ def create_node(
55
62
  """."""
56
63
  # Create node
57
64
  node_id = state.create_node(ping_interval=request.ping_interval)
58
- return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
65
+ return CreateNodeResponse(node=Node(node_id=node_id))
59
66
 
60
67
 
61
68
  def delete_node(request: DeleteNodeRequest, state: LinkState) -> DeleteNodeResponse:
62
69
  """."""
63
70
  # Validate node_id
64
- if request.node.anonymous or request.node.node_id == 0:
71
+ if request.node.node_id == 0: # i.e. unset `node_id`
65
72
  return DeleteNodeResponse()
66
73
 
67
74
  # Update state
@@ -78,27 +85,33 @@ def ping(
78
85
  return PingResponse(success=res)
79
86
 
80
87
 
81
- def pull_task_ins(request: PullTaskInsRequest, state: LinkState) -> PullTaskInsResponse:
82
- """Pull TaskIns handler."""
88
+ def pull_messages(
89
+ request: PullMessagesRequest, state: LinkState
90
+ ) -> PullMessagesResponse:
91
+ """Pull Messages handler."""
83
92
  # Get node_id if client node is not anonymous
84
93
  node = request.node # pylint: disable=no-member
85
- node_id: Optional[int] = None if node.anonymous else node.node_id
94
+ node_id: int = node.node_id
86
95
 
87
96
  # Retrieve TaskIns from State
88
97
  task_ins_list: list[TaskIns] = state.get_task_ins(node_id=node_id, limit=1)
89
98
 
90
- # Build response
91
- response = PullTaskInsResponse(
92
- task_ins_list=task_ins_list,
93
- )
94
- return response
99
+ # Convert to Messages
100
+ msg_proto = []
101
+ for task_ins in task_ins_list:
102
+ msg = message_from_taskins(task_ins)
103
+ msg_proto.append(message_to_proto(msg))
104
+
105
+ return PullMessagesResponse(messages_list=msg_proto)
95
106
 
96
107
 
97
- def push_task_res(request: PushTaskResRequest, state: LinkState) -> PushTaskResResponse:
98
- """Push TaskRes handler."""
99
- # pylint: disable=no-member
100
- task_res: TaskRes = request.task_res_list[0]
101
- # pylint: enable=no-member
108
+ def push_messages(
109
+ request: PushMessagesRequest, state: LinkState
110
+ ) -> PushMessagesResponse:
111
+ """Push Messages handler."""
112
+ # Convert Message to TaskRes
113
+ msg = message_from_proto(message_proto=request.messages_list[0])
114
+ task_res = message_to_taskres(msg)
102
115
 
103
116
  # Abort if the run is not running
104
117
  abort_msg = check_abort(
@@ -113,12 +126,12 @@ def push_task_res(request: PushTaskResRequest, state: LinkState) -> PushTaskResR
113
126
  task_res.task.pushed_at = time.time()
114
127
 
115
128
  # Store TaskRes in State
116
- task_id: Optional[UUID] = state.store_task_res(task_res=task_res)
129
+ message_id: Optional[UUID] = state.store_task_res(task_res=task_res)
117
130
 
118
131
  # Build response
119
- response = PushTaskResResponse(
132
+ response = PushMessagesResponse(
120
133
  reconnect=Reconnect(reconnect=5),
121
- results={str(task_id): 0},
134
+ results={str(message_id): 0},
122
135
  )
123
136
  return response
124
137
 
@@ -17,13 +17,12 @@
17
17
 
18
18
  from __future__ import annotations
19
19
 
20
- import sys
21
20
  from collections.abc import Awaitable
22
21
  from typing import Callable, TypeVar, cast
23
22
 
24
23
  from google.protobuf.message import Message as GrpcMessage
25
24
 
26
- from flwr.common.constant import MISSING_EXTRA_REST
25
+ from flwr.common.exit import ExitCode, flwr_exit
27
26
  from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
28
27
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
29
28
  CreateNodeRequest,
@@ -32,10 +31,10 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
32
31
  DeleteNodeResponse,
33
32
  PingRequest,
34
33
  PingResponse,
35
- PullTaskInsRequest,
36
- PullTaskInsResponse,
37
- PushTaskResRequest,
38
- PushTaskResResponse,
34
+ PullMessagesRequest,
35
+ PullMessagesResponse,
36
+ PushMessagesRequest,
37
+ PushMessagesResponse,
39
38
  )
40
39
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
41
40
  from flwr.server.superlink.ffs.ffs import Ffs
@@ -51,7 +50,7 @@ try:
51
50
  from starlette.responses import Response
52
51
  from starlette.routing import Route
53
52
  except ModuleNotFoundError:
54
- sys.exit(MISSING_EXTRA_REST)
53
+ flwr_exit(ExitCode.COMMON_MISSING_EXTRA_REST)
55
54
 
56
55
 
57
56
  GrpcRequest = TypeVar("GrpcRequest", bound=GrpcMessage)
@@ -107,25 +106,24 @@ async def delete_node(request: DeleteNodeRequest) -> DeleteNodeResponse:
107
106
  return message_handler.delete_node(request=request, state=state)
108
107
 
109
108
 
110
- @rest_request_response(PullTaskInsRequest)
111
- async def pull_task_ins(request: PullTaskInsRequest) -> PullTaskInsResponse:
112
- """Pull TaskIns."""
109
+ @rest_request_response(PullMessagesRequest)
110
+ async def pull_message(request: PullMessagesRequest) -> PullMessagesResponse:
111
+ """Pull PullMessages."""
113
112
  # Get state from app
114
113
  state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
115
114
 
116
115
  # Handle message
117
- return message_handler.pull_task_ins(request=request, state=state)
116
+ return message_handler.pull_messages(request=request, state=state)
118
117
 
119
118
 
120
- # Check if token is needed here
121
- @rest_request_response(PushTaskResRequest)
122
- async def push_task_res(request: PushTaskResRequest) -> PushTaskResResponse:
123
- """Push TaskRes."""
119
+ @rest_request_response(PushMessagesRequest)
120
+ async def push_message(request: PushMessagesRequest) -> PushMessagesResponse:
121
+ """Pull PushMessages."""
124
122
  # Get state from app
125
123
  state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
126
124
 
127
125
  # Handle message
128
- return message_handler.push_task_res(request=request, state=state)
126
+ return message_handler.push_messages(request=request, state=state)
129
127
 
130
128
 
131
129
  @rest_request_response(PingRequest)
@@ -164,8 +162,8 @@ async def get_fab(request: GetFabRequest) -> GetFabResponse:
164
162
  routes = [
165
163
  Route("/api/v0/fleet/create-node", create_node, methods=["POST"]),
166
164
  Route("/api/v0/fleet/delete-node", delete_node, methods=["POST"]),
167
- Route("/api/v0/fleet/pull-task-ins", pull_task_ins, methods=["POST"]),
168
- Route("/api/v0/fleet/push-task-res", push_task_res, methods=["POST"]),
165
+ Route("/api/v0/fleet/pull-messages", pull_message, methods=["POST"]),
166
+ Route("/api/v0/fleet/push-messages", push_message, methods=["POST"]),
169
167
  Route("/api/v0/fleet/ping", ping, methods=["POST"]),
170
168
  Route("/api/v0/fleet/get-run", get_run, methods=["POST"]),
171
169
  Route("/api/v0/fleet/get-fab", get_fab, methods=["POST"]),
@@ -182,8 +182,8 @@ def run_api(
182
182
  f_stop: threading.Event,
183
183
  ) -> None:
184
184
  """Run the VCE."""
185
- taskins_queue: "Queue[TaskIns]" = Queue()
186
- taskres_queue: "Queue[TaskRes]" = Queue()
185
+ taskins_queue: Queue[TaskIns] = Queue()
186
+ taskres_queue: Queue[TaskRes] = Queue()
187
187
 
188
188
  try:
189
189