flwr-nightly 1.15.0.dev20250104__py3-none-any.whl → 1.15.0.dev20250123__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98):
  1. flwr/cli/cli_user_auth_interceptor.py +6 -2
  2. flwr/cli/config_utils.py +23 -146
  3. flwr/cli/constant.py +27 -0
  4. flwr/cli/install.py +1 -1
  5. flwr/cli/log.py +17 -2
  6. flwr/cli/login/login.py +20 -5
  7. flwr/cli/ls.py +10 -2
  8. flwr/cli/run/run.py +20 -10
  9. flwr/cli/stop.py +9 -1
  10. flwr/cli/utils.py +4 -4
  11. flwr/client/app.py +36 -48
  12. flwr/client/clientapp/app.py +4 -6
  13. flwr/client/clientapp/utils.py +1 -1
  14. flwr/client/grpc_client/connection.py +0 -6
  15. flwr/client/grpc_rere_client/client_interceptor.py +19 -119
  16. flwr/client/grpc_rere_client/connection.py +34 -24
  17. flwr/client/grpc_rere_client/grpc_adapter.py +16 -0
  18. flwr/client/rest_client/connection.py +34 -26
  19. flwr/client/supernode/app.py +14 -20
  20. flwr/common/auth_plugin/auth_plugin.py +34 -23
  21. flwr/common/config.py +152 -15
  22. flwr/common/constant.py +11 -8
  23. flwr/common/exit/__init__.py +24 -0
  24. flwr/common/exit/exit.py +99 -0
  25. flwr/common/exit/exit_code.py +93 -0
  26. flwr/common/exit_handlers.py +24 -10
  27. flwr/common/grpc.py +161 -3
  28. flwr/common/logger.py +1 -1
  29. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +45 -0
  30. flwr/common/serde.py +6 -4
  31. flwr/common/typing.py +20 -0
  32. flwr/proto/clientappio_pb2.py +13 -3
  33. flwr/proto/clientappio_pb2_grpc.py +63 -12
  34. flwr/proto/error_pb2.py +13 -3
  35. flwr/proto/error_pb2_grpc.py +20 -0
  36. flwr/proto/exec_pb2.py +27 -29
  37. flwr/proto/exec_pb2.pyi +27 -54
  38. flwr/proto/exec_pb2_grpc.py +105 -24
  39. flwr/proto/fab_pb2.py +13 -3
  40. flwr/proto/fab_pb2_grpc.py +20 -0
  41. flwr/proto/fleet_pb2.py +54 -31
  42. flwr/proto/fleet_pb2.pyi +84 -0
  43. flwr/proto/fleet_pb2_grpc.py +207 -28
  44. flwr/proto/fleet_pb2_grpc.pyi +26 -0
  45. flwr/proto/grpcadapter_pb2.py +14 -4
  46. flwr/proto/grpcadapter_pb2_grpc.py +35 -4
  47. flwr/proto/log_pb2.py +13 -3
  48. flwr/proto/log_pb2_grpc.py +20 -0
  49. flwr/proto/message_pb2.py +15 -5
  50. flwr/proto/message_pb2_grpc.py +20 -0
  51. flwr/proto/node_pb2.py +15 -5
  52. flwr/proto/node_pb2.pyi +1 -4
  53. flwr/proto/node_pb2_grpc.py +20 -0
  54. flwr/proto/recordset_pb2.py +18 -8
  55. flwr/proto/recordset_pb2_grpc.py +20 -0
  56. flwr/proto/run_pb2.py +16 -6
  57. flwr/proto/run_pb2_grpc.py +20 -0
  58. flwr/proto/serverappio_pb2.py +32 -14
  59. flwr/proto/serverappio_pb2.pyi +56 -0
  60. flwr/proto/serverappio_pb2_grpc.py +261 -44
  61. flwr/proto/serverappio_pb2_grpc.pyi +20 -0
  62. flwr/proto/simulationio_pb2.py +13 -3
  63. flwr/proto/simulationio_pb2_grpc.py +105 -24
  64. flwr/proto/task_pb2.py +13 -3
  65. flwr/proto/task_pb2_grpc.py +20 -0
  66. flwr/proto/transport_pb2.py +20 -10
  67. flwr/proto/transport_pb2_grpc.py +35 -4
  68. flwr/server/app.py +87 -38
  69. flwr/server/compat/app_utils.py +0 -1
  70. flwr/server/compat/driver_client_proxy.py +1 -2
  71. flwr/server/driver/grpc_driver.py +5 -2
  72. flwr/server/driver/inmemory_driver.py +2 -1
  73. flwr/server/serverapp/app.py +5 -6
  74. flwr/server/superlink/driver/serverappio_grpc.py +1 -1
  75. flwr/server/superlink/driver/serverappio_servicer.py +132 -14
  76. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +20 -88
  77. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +2 -165
  78. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +38 -0
  79. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +95 -168
  80. flwr/server/superlink/fleet/message_handler/message_handler.py +66 -5
  81. flwr/server/superlink/fleet/rest_rere/rest_api.py +28 -3
  82. flwr/server/superlink/fleet/vce/vce_api.py +2 -2
  83. flwr/server/superlink/linkstate/in_memory_linkstate.py +40 -48
  84. flwr/server/superlink/linkstate/linkstate.py +15 -22
  85. flwr/server/superlink/linkstate/sqlite_linkstate.py +80 -99
  86. flwr/server/superlink/linkstate/utils.py +18 -8
  87. flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
  88. flwr/server/utils/validator.py +9 -34
  89. flwr/simulation/app.py +4 -6
  90. flwr/simulation/legacy_app.py +4 -2
  91. flwr/simulation/run_simulation.py +1 -1
  92. flwr/superexec/exec_grpc.py +1 -1
  93. flwr/superexec/exec_servicer.py +23 -2
  94. {flwr_nightly-1.15.0.dev20250104.dist-info → flwr_nightly-1.15.0.dev20250123.dist-info}/METADATA +7 -7
  95. {flwr_nightly-1.15.0.dev20250104.dist-info → flwr_nightly-1.15.0.dev20250123.dist-info}/RECORD +98 -94
  96. {flwr_nightly-1.15.0.dev20250104.dist-info → flwr_nightly-1.15.0.dev20250123.dist-info}/LICENSE +0 -0
  97. {flwr_nightly-1.15.0.dev20250104.dist-info → flwr_nightly-1.15.0.dev20250123.dist-info}/WHEEL +0 -0
  98. {flwr_nightly-1.15.0.dev20250104.dist-info → flwr_nightly-1.15.0.dev20250123.dist-info}/entry_points.txt +0 -0
@@ -30,8 +30,12 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
30
30
  DeleteNodeResponse,
31
31
  PingRequest,
32
32
  PingResponse,
33
+ PullMessagesRequest,
34
+ PullMessagesResponse,
33
35
  PullTaskInsRequest,
34
36
  PullTaskInsResponse,
37
+ PushMessagesRequest,
38
+ PushMessagesResponse,
35
39
  PushTaskResRequest,
36
40
  PushTaskResResponse,
37
41
  )
@@ -95,6 +99,17 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
95
99
  state=self.state_factory.state(),
96
100
  )
97
101
 
102
+ def PullMessages(
103
+ self, request: PullMessagesRequest, context: grpc.ServicerContext
104
+ ) -> PullMessagesResponse:
105
+ """Pull Messages."""
106
+ log(INFO, "[Fleet.PullMessages] node_id=%s", request.node.node_id)
107
+ log(DEBUG, "[Fleet.PullMessages] Request: %s", request)
108
+ return message_handler.pull_messages(
109
+ request=request,
110
+ state=self.state_factory.state(),
111
+ )
112
+
98
113
  def PushTaskRes(
99
114
  self, request: PushTaskResRequest, context: grpc.ServicerContext
100
115
  ) -> PushTaskResResponse:
@@ -118,6 +133,29 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
118
133
 
119
134
  return res
120
135
 
136
+ def PushMessages(
137
+ self, request: PushMessagesRequest, context: grpc.ServicerContext
138
+ ) -> PushMessagesResponse:
139
+ """Push Messages."""
140
+ if request.messages_list:
141
+ log(
142
+ INFO,
143
+ "[Fleet.PushMessages] Push results from node_id=%s",
144
+ request.messages_list[0].metadata.src_node_id,
145
+ )
146
+ else:
147
+ log(INFO, "[Fleet.PushMessages] No task results to push")
148
+
149
+ try:
150
+ res = message_handler.push_messages(
151
+ request=request,
152
+ state=self.state_factory.state(),
153
+ )
154
+ except InvalidRunStatusException as e:
155
+ abort_grpc_context(e.message, context)
156
+
157
+ return res
158
+
121
159
  def GetRun(
122
160
  self, request: GetRunRequest, context: grpc.ServicerContext
123
161
  ) -> GetRunResponse:
@@ -15,91 +15,54 @@
15
15
  """Flower server interceptor."""
16
16
 
17
17
 
18
- import base64
19
- from collections.abc import Sequence
20
- from logging import INFO, WARNING
21
- from typing import Any, Callable, Optional, Union
18
+ import datetime
19
+ from typing import Any, Callable, Optional, cast
22
20
 
23
21
  import grpc
24
- from cryptography.hazmat.primitives.asymmetric import ec
25
-
26
- from flwr.common.logger import log
22
+ from google.protobuf.message import Message as GrpcMessage
23
+
24
+ from flwr.common import now
25
+ from flwr.common.constant import (
26
+ PUBLIC_KEY_HEADER,
27
+ SIGNATURE_HEADER,
28
+ TIMESTAMP_HEADER,
29
+ TIMESTAMP_TOLERANCE,
30
+ )
27
31
  from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
28
- bytes_to_private_key,
29
32
  bytes_to_public_key,
30
- generate_shared_key,
31
- verify_hmac,
33
+ verify_signature,
32
34
  )
33
- from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
34
35
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
35
36
  CreateNodeRequest,
36
37
  CreateNodeResponse,
37
- DeleteNodeRequest,
38
- DeleteNodeResponse,
39
- PingRequest,
40
- PingResponse,
41
- PullTaskInsRequest,
42
- PullTaskInsResponse,
43
- PushTaskResRequest,
44
- PushTaskResResponse,
45
38
  )
46
- from flwr.proto.node_pb2 import Node # pylint: disable=E0611
47
- from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
48
39
  from flwr.server.superlink.linkstate import LinkStateFactory
49
40
 
50
- _PUBLIC_KEY_HEADER = "public-key"
51
- _AUTH_TOKEN_HEADER = "auth-token"
52
-
53
- Request = Union[
54
- CreateNodeRequest,
55
- DeleteNodeRequest,
56
- PullTaskInsRequest,
57
- PushTaskResRequest,
58
- GetRunRequest,
59
- PingRequest,
60
- GetFabRequest,
61
- ]
62
-
63
- Response = Union[
64
- CreateNodeResponse,
65
- DeleteNodeResponse,
66
- PullTaskInsResponse,
67
- PushTaskResResponse,
68
- GetRunResponse,
69
- PingResponse,
70
- GetFabResponse,
71
- ]
72
-
73
41
 
74
- def _get_value_from_tuples(
75
- key_string: str, tuples: Sequence[tuple[str, Union[str, bytes]]]
76
- ) -> bytes:
77
- value = next((value for key, value in tuples if key == key_string), "")
78
- if isinstance(value, str):
79
- return value.encode()
42
+ def _unary_unary_rpc_terminator(message: str) -> grpc.RpcMethodHandler:
43
+ def terminate(_request: GrpcMessage, context: grpc.ServicerContext) -> GrpcMessage:
44
+ context.abort(grpc.StatusCode.UNAUTHENTICATED, message)
45
+ raise RuntimeError("Should not reach this point") # Make mypy happy
80
46
 
81
- return value
47
+ return grpc.unary_unary_rpc_method_handler(terminate)
82
48
 
83
49
 
84
50
  class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
85
- """Server interceptor for node authentication."""
86
-
87
- def __init__(self, state_factory: LinkStateFactory):
51
+ """Server interceptor for node authentication.
52
+
53
+ Parameters
54
+ ----------
55
+ state_factory : LinkStateFactory
56
+ A factory for creating new instances of LinkState.
57
+ auto_auth : bool (default: False)
58
+ If True, nodes are authenticated without requiring their public keys to be
59
+ pre-stored in the LinkState. If False, only nodes with pre-stored public keys
60
+ can be authenticated.
61
+ """
62
+
63
+ def __init__(self, state_factory: LinkStateFactory, auto_auth: bool = False):
88
64
  self.state_factory = state_factory
89
- state = self.state_factory.state()
90
-
91
- self.node_public_keys = state.get_node_public_keys()
92
- if len(self.node_public_keys) == 0:
93
- log(WARNING, "Authentication enabled, but no known public keys configured")
94
-
95
- private_key = state.get_server_private_key()
96
- public_key = state.get_server_public_key()
97
-
98
- if private_key is None or public_key is None:
99
- raise ValueError("Error loading authentication keys")
100
-
101
- self.server_private_key = bytes_to_private_key(private_key)
102
- self.encoded_server_public_key = base64.urlsafe_b64encode(public_key)
65
+ self.auto_auth = auto_auth
103
66
 
104
67
  def intercept_service(
105
68
  self,
@@ -112,116 +75,80 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
112
75
  metadata sent by the node. Continue RPC call if node is authenticated, else,
113
76
  terminate RPC call by setting context to abort.
114
77
  """
78
+ state = self.state_factory.state()
79
+ metadata_dict = dict(handler_call_details.invocation_metadata)
80
+
81
+ # Retrieve info from the metadata
82
+ try:
83
+ node_pk_bytes = cast(bytes, metadata_dict[PUBLIC_KEY_HEADER])
84
+ timestamp_iso = cast(str, metadata_dict[TIMESTAMP_HEADER])
85
+ signature = cast(bytes, metadata_dict[SIGNATURE_HEADER])
86
+ except KeyError:
87
+ return _unary_unary_rpc_terminator("Missing authentication metadata")
88
+
89
+ if not self.auto_auth:
90
+ # Abort the RPC call if the node public key is not found
91
+ if node_pk_bytes not in state.get_node_public_keys():
92
+ return _unary_unary_rpc_terminator("Public key not recognized")
93
+
94
+ # Verify the signature
95
+ node_pk = bytes_to_public_key(node_pk_bytes)
96
+ if not verify_signature(node_pk, timestamp_iso.encode("ascii"), signature):
97
+ return _unary_unary_rpc_terminator("Invalid signature")
98
+
99
+ # Verify the timestamp
100
+ current = now()
101
+ time_diff = current - datetime.datetime.fromisoformat(timestamp_iso)
102
+ # Abort the RPC call if the timestamp is too old or in the future
103
+ if not 0 < time_diff.total_seconds() < TIMESTAMP_TOLERANCE:
104
+ return _unary_unary_rpc_terminator("Invalid timestamp")
105
+
106
+ # Continue the RPC call
107
+ expected_node_id = state.get_node_id(node_pk_bytes)
108
+ if not handler_call_details.method.endswith("CreateNode"):
109
+ if expected_node_id is None:
110
+ return _unary_unary_rpc_terminator("Invalid node ID")
115
111
  # One of the method handlers in
116
112
  # `flwr.server.superlink.fleet.grpc_rere.fleet_server.FleetServicer`
117
113
  method_handler: grpc.RpcMethodHandler = continuation(handler_call_details)
118
- return self._generic_auth_unary_method_handler(method_handler)
114
+ return self._wrap_method_handler(
115
+ method_handler, expected_node_id, node_pk_bytes
116
+ )
119
117
 
120
- def _generic_auth_unary_method_handler(
121
- self, method_handler: grpc.RpcMethodHandler
118
+ def _wrap_method_handler(
119
+ self,
120
+ method_handler: grpc.RpcMethodHandler,
121
+ expected_node_id: Optional[int],
122
+ node_public_key: bytes,
122
123
  ) -> grpc.RpcMethodHandler:
123
124
  def _generic_method_handler(
124
- request: Request,
125
+ request: GrpcMessage,
125
126
  context: grpc.ServicerContext,
126
- ) -> Response:
127
- node_public_key_bytes = base64.urlsafe_b64decode(
128
- _get_value_from_tuples(
129
- _PUBLIC_KEY_HEADER, context.invocation_metadata()
130
- )
131
- )
132
- if node_public_key_bytes not in self.node_public_keys:
133
- context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
134
-
135
- if isinstance(request, CreateNodeRequest):
136
- response = self._create_authenticated_node(
137
- node_public_key_bytes, request, context
138
- )
139
- log(
140
- INFO,
141
- "AuthenticateServerInterceptor: Created node_id=%s",
142
- response.node.node_id,
143
- )
144
- return response
145
-
146
- # Verify hmac value
147
- hmac_value = base64.urlsafe_b64decode(
148
- _get_value_from_tuples(
149
- _AUTH_TOKEN_HEADER, context.invocation_metadata()
150
- )
151
- )
152
- public_key = bytes_to_public_key(node_public_key_bytes)
153
-
154
- if not self._verify_hmac(public_key, request, hmac_value):
155
- context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
156
-
157
- # Verify node_id
158
- node_id = self.state_factory.state().get_node_id(node_public_key_bytes)
159
-
160
- if not self._verify_node_id(node_id, request):
161
- context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
162
-
163
- return method_handler.unary_unary(request, context) # type: ignore
127
+ ) -> GrpcMessage:
128
+ # Verify the node ID
129
+ if not isinstance(request, CreateNodeRequest):
130
+ try:
131
+ if request.node.node_id != expected_node_id: # type: ignore
132
+ raise ValueError
133
+ except (AttributeError, ValueError):
134
+ context.abort(grpc.StatusCode.UNAUTHENTICATED, "Invalid node ID")
135
+
136
+ response: GrpcMessage = method_handler.unary_unary(request, context)
137
+
138
+ # Set the public key after a successful CreateNode request
139
+ if isinstance(response, CreateNodeResponse):
140
+ state = self.state_factory.state()
141
+ try:
142
+ state.set_node_public_key(response.node.node_id, node_public_key)
143
+ except ValueError as e:
144
+ # Remove newly created node if setting the public key fails
145
+ state.delete_node(response.node.node_id)
146
+ context.abort(grpc.StatusCode.UNAUTHENTICATED, str(e))
147
+
148
+ return response
164
149
 
165
150
  return grpc.unary_unary_rpc_method_handler(
166
151
  _generic_method_handler,
167
152
  request_deserializer=method_handler.request_deserializer,
168
153
  response_serializer=method_handler.response_serializer,
169
154
  )
170
-
171
- def _verify_node_id(
172
- self,
173
- node_id: Optional[int],
174
- request: Union[
175
- DeleteNodeRequest,
176
- PullTaskInsRequest,
177
- PushTaskResRequest,
178
- GetRunRequest,
179
- PingRequest,
180
- GetFabRequest,
181
- ],
182
- ) -> bool:
183
- if node_id is None:
184
- return False
185
- if isinstance(request, PushTaskResRequest):
186
- if len(request.task_res_list) == 0:
187
- return False
188
- return request.task_res_list[0].task.producer.node_id == node_id
189
- if isinstance(request, GetRunRequest):
190
- return node_id in self.state_factory.state().get_nodes(request.run_id)
191
- return request.node.node_id == node_id
192
-
193
- def _verify_hmac(
194
- self, public_key: ec.EllipticCurvePublicKey, request: Request, hmac_value: bytes
195
- ) -> bool:
196
- shared_secret = generate_shared_key(self.server_private_key, public_key)
197
- message_bytes = request.SerializeToString(deterministic=True)
198
- return verify_hmac(shared_secret, message_bytes, hmac_value)
199
-
200
- def _create_authenticated_node(
201
- self,
202
- public_key_bytes: bytes,
203
- request: CreateNodeRequest,
204
- context: grpc.ServicerContext,
205
- ) -> CreateNodeResponse:
206
- context.send_initial_metadata(
207
- (
208
- (
209
- _PUBLIC_KEY_HEADER,
210
- self.encoded_server_public_key,
211
- ),
212
- )
213
- )
214
- state = self.state_factory.state()
215
- node_id = state.get_node_id(public_key_bytes)
216
-
217
- # Handle `CreateNode` here instead of calling the default method handler
218
- # Return previously assigned `node_id` for the provided `public_key`
219
- if node_id is not None:
220
- state.acknowledge_ping(node_id, request.ping_interval)
221
- return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
222
-
223
- # No `node_id` exists for the provided `public_key`
224
- # Handle `CreateNode` here instead of calling the default method handler
225
- # Note: the innermost `CreateNode` method will never be called
226
- node_id = state.create_node(request.ping_interval, public_key_bytes)
227
- return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
@@ -20,7 +20,14 @@ from typing import Optional
20
20
  from uuid import UUID
21
21
 
22
22
  from flwr.common.constant import Status
23
- from flwr.common.serde import fab_to_proto, user_config_to_proto
23
+ from flwr.common.serde import (
24
+ fab_to_proto,
25
+ message_from_proto,
26
+ message_from_taskins,
27
+ message_to_proto,
28
+ message_to_taskres,
29
+ user_config_to_proto,
30
+ )
24
31
  from flwr.common.typing import Fab, InvalidRunStatusException
25
32
  from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
26
33
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
@@ -30,8 +37,12 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
30
37
  DeleteNodeResponse,
31
38
  PingRequest,
32
39
  PingResponse,
40
+ PullMessagesRequest,
41
+ PullMessagesResponse,
33
42
  PullTaskInsRequest,
34
43
  PullTaskInsResponse,
44
+ PushMessagesRequest,
45
+ PushMessagesResponse,
35
46
  PushTaskResRequest,
36
47
  PushTaskResResponse,
37
48
  Reconnect,
@@ -55,13 +66,13 @@ def create_node(
55
66
  """."""
56
67
  # Create node
57
68
  node_id = state.create_node(ping_interval=request.ping_interval)
58
- return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))
69
+ return CreateNodeResponse(node=Node(node_id=node_id))
59
70
 
60
71
 
61
72
  def delete_node(request: DeleteNodeRequest, state: LinkState) -> DeleteNodeResponse:
62
73
  """."""
63
74
  # Validate node_id
64
- if request.node.anonymous or request.node.node_id == 0:
75
+ if request.node.node_id == 0: # i.e. unset `node_id`
65
76
  return DeleteNodeResponse()
66
77
 
67
78
  # Update state
@@ -80,9 +91,8 @@ def ping(
80
91
 
81
92
  def pull_task_ins(request: PullTaskInsRequest, state: LinkState) -> PullTaskInsResponse:
82
93
  """Pull TaskIns handler."""
83
- # Get node_id if client node is not anonymous
84
94
  node = request.node # pylint: disable=no-member
85
- node_id: Optional[int] = None if node.anonymous else node.node_id
95
+ node_id: int = node.node_id
86
96
 
87
97
  # Retrieve TaskIns from State
88
98
  task_ins_list: list[TaskIns] = state.get_task_ins(node_id=node_id, limit=1)
@@ -94,6 +104,26 @@ def pull_task_ins(request: PullTaskInsRequest, state: LinkState) -> PullTaskInsR
94
104
  return response
95
105
 
96
106
 
107
+ def pull_messages(
108
+ request: PullMessagesRequest, state: LinkState
109
+ ) -> PullMessagesResponse:
110
+ """Pull Messages handler."""
111
+ # Get node_id if client node is not anonymous
112
+ node = request.node # pylint: disable=no-member
113
+ node_id: int = node.node_id
114
+
115
+ # Retrieve TaskIns from State
116
+ task_ins_list: list[TaskIns] = state.get_task_ins(node_id=node_id, limit=1)
117
+
118
+ # Convert to Messages
119
+ msg_proto = []
120
+ for task_ins in task_ins_list:
121
+ msg = message_from_taskins(task_ins)
122
+ msg_proto.append(message_to_proto(msg))
123
+
124
+ return PullMessagesResponse(messages_list=msg_proto)
125
+
126
+
97
127
  def push_task_res(request: PushTaskResRequest, state: LinkState) -> PushTaskResResponse:
98
128
  """Push TaskRes handler."""
99
129
  # pylint: disable=no-member
@@ -123,6 +153,37 @@ def push_task_res(request: PushTaskResRequest, state: LinkState) -> PushTaskResR
123
153
  return response
124
154
 
125
155
 
156
+ def push_messages(
157
+ request: PushMessagesRequest, state: LinkState
158
+ ) -> PushMessagesResponse:
159
+ """Push Messages handler."""
160
+ # Convert Message to TaskRes
161
+ msg = message_from_proto(message_proto=request.messages_list[0])
162
+ task_res = message_to_taskres(msg)
163
+
164
+ # Abort if the run is not running
165
+ abort_msg = check_abort(
166
+ task_res.run_id,
167
+ [Status.PENDING, Status.STARTING, Status.FINISHED],
168
+ state,
169
+ )
170
+ if abort_msg:
171
+ raise InvalidRunStatusException(abort_msg)
172
+
173
+ # Set pushed_at (timestamp in seconds)
174
+ task_res.task.pushed_at = time.time()
175
+
176
+ # Store TaskRes in State
177
+ message_id: Optional[UUID] = state.store_task_res(task_res=task_res)
178
+
179
+ # Build response
180
+ response = PushMessagesResponse(
181
+ reconnect=Reconnect(reconnect=5),
182
+ results={str(message_id): 0},
183
+ )
184
+ return response
185
+
186
+
126
187
  def get_run(request: GetRunRequest, state: LinkState) -> GetRunResponse:
127
188
  """Get run information."""
128
189
  run = state.get_run(request.run_id)
@@ -17,13 +17,12 @@
17
17
 
18
18
  from __future__ import annotations
19
19
 
20
- import sys
21
20
  from collections.abc import Awaitable
22
21
  from typing import Callable, TypeVar, cast
23
22
 
24
23
  from google.protobuf.message import Message as GrpcMessage
25
24
 
26
- from flwr.common.constant import MISSING_EXTRA_REST
25
+ from flwr.common.exit import ExitCode, flwr_exit
27
26
  from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
28
27
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
29
28
  CreateNodeRequest,
@@ -32,8 +31,12 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
32
31
  DeleteNodeResponse,
33
32
  PingRequest,
34
33
  PingResponse,
34
+ PullMessagesRequest,
35
+ PullMessagesResponse,
35
36
  PullTaskInsRequest,
36
37
  PullTaskInsResponse,
38
+ PushMessagesRequest,
39
+ PushMessagesResponse,
37
40
  PushTaskResRequest,
38
41
  PushTaskResResponse,
39
42
  )
@@ -51,7 +54,7 @@ try:
51
54
  from starlette.responses import Response
52
55
  from starlette.routing import Route
53
56
  except ModuleNotFoundError:
54
- sys.exit(MISSING_EXTRA_REST)
57
+ flwr_exit(ExitCode.COMMON_MISSING_EXTRA_REST)
55
58
 
56
59
 
57
60
  GrpcRequest = TypeVar("GrpcRequest", bound=GrpcMessage)
@@ -117,6 +120,16 @@ async def pull_task_ins(request: PullTaskInsRequest) -> PullTaskInsResponse:
117
120
  return message_handler.pull_task_ins(request=request, state=state)
118
121
 
119
122
 
123
+ @rest_request_response(PullMessagesRequest)
124
+ async def pull_message(request: PullMessagesRequest) -> PullMessagesResponse:
125
+ """Pull PullMessages."""
126
+ # Get state from app
127
+ state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
128
+
129
+ # Handle message
130
+ return message_handler.pull_messages(request=request, state=state)
131
+
132
+
120
133
  # Check if token is needed here
121
134
  @rest_request_response(PushTaskResRequest)
122
135
  async def push_task_res(request: PushTaskResRequest) -> PushTaskResResponse:
@@ -128,6 +141,16 @@ async def push_task_res(request: PushTaskResRequest) -> PushTaskResResponse:
128
141
  return message_handler.push_task_res(request=request, state=state)
129
142
 
130
143
 
144
+ @rest_request_response(PushMessagesRequest)
145
+ async def push_message(request: PushMessagesRequest) -> PushMessagesResponse:
146
+ """Pull PushMessages."""
147
+ # Get state from app
148
+ state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()
149
+
150
+ # Handle message
151
+ return message_handler.push_messages(request=request, state=state)
152
+
153
+
131
154
  @rest_request_response(PingRequest)
132
155
  async def ping(request: PingRequest) -> PingResponse:
133
156
  """Ping."""
@@ -165,7 +188,9 @@ routes = [
165
188
  Route("/api/v0/fleet/create-node", create_node, methods=["POST"]),
166
189
  Route("/api/v0/fleet/delete-node", delete_node, methods=["POST"]),
167
190
  Route("/api/v0/fleet/pull-task-ins", pull_task_ins, methods=["POST"]),
191
+ Route("/api/v0/fleet/pull-messages", pull_message, methods=["POST"]),
168
192
  Route("/api/v0/fleet/push-task-res", push_task_res, methods=["POST"]),
193
+ Route("/api/v0/fleet/push-messages", push_message, methods=["POST"]),
169
194
  Route("/api/v0/fleet/ping", ping, methods=["POST"]),
170
195
  Route("/api/v0/fleet/get-run", get_run, methods=["POST"]),
171
196
  Route("/api/v0/fleet/get-fab", get_fab, methods=["POST"]),
@@ -182,8 +182,8 @@ def run_api(
182
182
  f_stop: threading.Event,
183
183
  ) -> None:
184
184
  """Run the VCE."""
185
- taskins_queue: "Queue[TaskIns]" = Queue()
186
- taskres_queue: "Queue[TaskRes]" = Queue()
185
+ taskins_queue: Queue[TaskIns] = Queue()
186
+ taskres_queue: Queue[TaskRes] = Queue()
187
187
 
188
188
  try:
189
189