flwr 1.18.0__py3-none-any.whl → 1.20.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (174)
  1. flwr/app/__init__.py +15 -0
  2. flwr/app/error.py +68 -0
  3. flwr/app/metadata.py +223 -0
  4. flwr/cli/build.py +94 -59
  5. flwr/cli/log.py +3 -3
  6. flwr/cli/login/login.py +3 -7
  7. flwr/cli/ls.py +15 -36
  8. flwr/cli/new/new.py +12 -4
  9. flwr/cli/new/templates/app/README.flowertune.md.tpl +2 -0
  10. flwr/cli/new/templates/app/README.md.tpl +5 -0
  11. flwr/cli/new/templates/app/code/client.baseline.py.tpl +1 -1
  12. flwr/cli/new/templates/app/code/model.baseline.py.tpl +1 -1
  13. flwr/cli/new/templates/app/code/server.baseline.py.tpl +2 -3
  14. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +25 -17
  15. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +13 -1
  16. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +21 -2
  17. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +18 -1
  18. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +19 -2
  19. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +18 -1
  20. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +20 -3
  21. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +18 -1
  22. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +18 -1
  23. flwr/cli/run/run.py +48 -49
  24. flwr/cli/stop.py +2 -2
  25. flwr/cli/utils.py +38 -5
  26. flwr/client/__init__.py +2 -2
  27. flwr/client/client_app.py +1 -1
  28. flwr/client/clientapp/__init__.py +0 -7
  29. flwr/client/grpc_adapter_client/connection.py +15 -8
  30. flwr/client/grpc_rere_client/connection.py +142 -97
  31. flwr/client/grpc_rere_client/grpc_adapter.py +34 -6
  32. flwr/client/message_handler/message_handler.py +1 -1
  33. flwr/client/mod/comms_mods.py +36 -17
  34. flwr/client/rest_client/connection.py +176 -103
  35. flwr/clientapp/__init__.py +15 -0
  36. flwr/common/__init__.py +2 -2
  37. flwr/common/auth_plugin/__init__.py +2 -0
  38. flwr/common/auth_plugin/auth_plugin.py +29 -3
  39. flwr/common/constant.py +39 -8
  40. flwr/common/event_log_plugin/event_log_plugin.py +3 -3
  41. flwr/common/exit/exit_code.py +16 -1
  42. flwr/common/exit_handlers.py +30 -0
  43. flwr/common/grpc.py +12 -1
  44. flwr/common/heartbeat.py +165 -0
  45. flwr/common/inflatable.py +290 -0
  46. flwr/common/inflatable_protobuf_utils.py +141 -0
  47. flwr/common/inflatable_utils.py +508 -0
  48. flwr/common/message.py +110 -242
  49. flwr/common/record/__init__.py +2 -1
  50. flwr/common/record/array.py +402 -0
  51. flwr/common/record/arraychunk.py +59 -0
  52. flwr/common/record/arrayrecord.py +103 -225
  53. flwr/common/record/configrecord.py +59 -4
  54. flwr/common/record/conversion_utils.py +1 -1
  55. flwr/common/record/metricrecord.py +55 -4
  56. flwr/common/record/recorddict.py +69 -1
  57. flwr/common/recorddict_compat.py +2 -2
  58. flwr/common/retry_invoker.py +5 -1
  59. flwr/common/serde.py +59 -211
  60. flwr/common/serde_utils.py +175 -0
  61. flwr/common/typing.py +5 -3
  62. flwr/compat/__init__.py +15 -0
  63. flwr/compat/client/__init__.py +15 -0
  64. flwr/{client → compat/client}/app.py +28 -185
  65. flwr/compat/common/__init__.py +15 -0
  66. flwr/compat/server/__init__.py +15 -0
  67. flwr/compat/server/app.py +174 -0
  68. flwr/compat/simulation/__init__.py +15 -0
  69. flwr/proto/appio_pb2.py +43 -0
  70. flwr/proto/appio_pb2.pyi +151 -0
  71. flwr/proto/appio_pb2_grpc.py +4 -0
  72. flwr/proto/appio_pb2_grpc.pyi +4 -0
  73. flwr/proto/clientappio_pb2.py +12 -19
  74. flwr/proto/clientappio_pb2.pyi +23 -101
  75. flwr/proto/clientappio_pb2_grpc.py +269 -28
  76. flwr/proto/clientappio_pb2_grpc.pyi +114 -20
  77. flwr/proto/fleet_pb2.py +24 -27
  78. flwr/proto/fleet_pb2.pyi +19 -35
  79. flwr/proto/fleet_pb2_grpc.py +117 -13
  80. flwr/proto/fleet_pb2_grpc.pyi +47 -6
  81. flwr/proto/heartbeat_pb2.py +33 -0
  82. flwr/proto/heartbeat_pb2.pyi +66 -0
  83. flwr/proto/heartbeat_pb2_grpc.py +4 -0
  84. flwr/proto/heartbeat_pb2_grpc.pyi +4 -0
  85. flwr/proto/message_pb2.py +28 -11
  86. flwr/proto/message_pb2.pyi +125 -0
  87. flwr/proto/recorddict_pb2.py +16 -28
  88. flwr/proto/recorddict_pb2.pyi +46 -64
  89. flwr/proto/run_pb2.py +24 -32
  90. flwr/proto/run_pb2.pyi +4 -52
  91. flwr/proto/serverappio_pb2.py +9 -23
  92. flwr/proto/serverappio_pb2.pyi +0 -110
  93. flwr/proto/serverappio_pb2_grpc.py +177 -72
  94. flwr/proto/serverappio_pb2_grpc.pyi +75 -33
  95. flwr/proto/simulationio_pb2.py +12 -11
  96. flwr/proto/simulationio_pb2_grpc.py +35 -0
  97. flwr/proto/simulationio_pb2_grpc.pyi +14 -0
  98. flwr/server/__init__.py +1 -1
  99. flwr/server/app.py +69 -187
  100. flwr/server/compat/app_utils.py +50 -28
  101. flwr/server/fleet_event_log_interceptor.py +6 -2
  102. flwr/server/grid/grpc_grid.py +148 -41
  103. flwr/server/grid/inmemory_grid.py +5 -4
  104. flwr/server/serverapp/app.py +45 -17
  105. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +21 -3
  106. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +102 -8
  107. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -5
  108. flwr/server/superlink/fleet/message_handler/message_handler.py +130 -19
  109. flwr/server/superlink/fleet/rest_rere/rest_api.py +73 -13
  110. flwr/server/superlink/fleet/vce/vce_api.py +6 -3
  111. flwr/server/superlink/linkstate/in_memory_linkstate.py +138 -43
  112. flwr/server/superlink/linkstate/linkstate.py +53 -20
  113. flwr/server/superlink/linkstate/sqlite_linkstate.py +149 -55
  114. flwr/server/superlink/linkstate/utils.py +33 -29
  115. flwr/server/superlink/serverappio/serverappio_grpc.py +4 -1
  116. flwr/server/superlink/serverappio/serverappio_servicer.py +230 -84
  117. flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
  118. flwr/server/superlink/simulation/simulationio_servicer.py +26 -2
  119. flwr/server/superlink/utils.py +9 -2
  120. flwr/server/utils/validator.py +2 -2
  121. flwr/serverapp/__init__.py +15 -0
  122. flwr/simulation/app.py +25 -0
  123. flwr/simulation/run_simulation.py +17 -0
  124. flwr/supercore/__init__.py +15 -0
  125. flwr/{server/superlink → supercore}/ffs/__init__.py +2 -0
  126. flwr/{server/superlink → supercore}/ffs/disk_ffs.py +1 -1
  127. flwr/supercore/grpc_health/__init__.py +22 -0
  128. flwr/supercore/grpc_health/simple_health_servicer.py +38 -0
  129. flwr/supercore/license_plugin/__init__.py +22 -0
  130. flwr/supercore/license_plugin/license_plugin.py +26 -0
  131. flwr/supercore/object_store/__init__.py +24 -0
  132. flwr/supercore/object_store/in_memory_object_store.py +229 -0
  133. flwr/supercore/object_store/object_store.py +170 -0
  134. flwr/supercore/object_store/object_store_factory.py +44 -0
  135. flwr/supercore/object_store/utils.py +43 -0
  136. flwr/supercore/scheduler/__init__.py +22 -0
  137. flwr/supercore/scheduler/plugin.py +71 -0
  138. flwr/{client/nodestate/nodestate.py → supercore/utils.py} +14 -13
  139. flwr/superexec/deployment.py +7 -4
  140. flwr/superexec/exec_event_log_interceptor.py +8 -4
  141. flwr/superexec/exec_grpc.py +25 -5
  142. flwr/superexec/exec_license_interceptor.py +82 -0
  143. flwr/superexec/exec_servicer.py +135 -24
  144. flwr/superexec/exec_user_auth_interceptor.py +45 -8
  145. flwr/superexec/executor.py +5 -1
  146. flwr/superexec/simulation.py +8 -3
  147. flwr/superlink/__init__.py +15 -0
  148. flwr/{client/supernode → supernode}/__init__.py +0 -7
  149. flwr/supernode/cli/__init__.py +24 -0
  150. flwr/{client/supernode/app.py → supernode/cli/flower_supernode.py} +3 -19
  151. flwr/supernode/cli/flwr_clientapp.py +88 -0
  152. flwr/supernode/nodestate/in_memory_nodestate.py +199 -0
  153. flwr/supernode/nodestate/nodestate.py +227 -0
  154. flwr/supernode/runtime/__init__.py +15 -0
  155. flwr/{client/clientapp/app.py → supernode/runtime/run_clientapp.py} +135 -89
  156. flwr/supernode/scheduler/__init__.py +22 -0
  157. flwr/supernode/scheduler/simple_clientapp_scheduler_plugin.py +49 -0
  158. flwr/supernode/servicer/__init__.py +15 -0
  159. flwr/supernode/servicer/clientappio/__init__.py +22 -0
  160. flwr/supernode/servicer/clientappio/clientappio_servicer.py +303 -0
  161. flwr/supernode/start_client_internal.py +589 -0
  162. {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/METADATA +6 -4
  163. {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/RECORD +171 -123
  164. {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/WHEEL +1 -1
  165. {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/entry_points.txt +2 -2
  166. flwr/client/clientapp/clientappio_servicer.py +0 -244
  167. flwr/client/heartbeat.py +0 -74
  168. flwr/client/nodestate/in_memory_nodestate.py +0 -38
  169. /flwr/{client → compat/client}/grpc_client/__init__.py +0 -0
  170. /flwr/{client → compat/client}/grpc_client/connection.py +0 -0
  171. /flwr/{server/superlink → supercore}/ffs/ffs.py +0 -0
  172. /flwr/{server/superlink → supercore}/ffs/ffs_factory.py +0 -0
  173. /flwr/{client → supernode}/nodestate/__init__.py +0 -0
  174. /flwr/{client → supernode}/nodestate/nodestate_factory.py +0 -0
@@ -22,26 +22,51 @@ from typing import Optional, cast
22
22
 
23
23
  import grpc
24
24
 
25
- from flwr.common import Message, RecordDict
25
+ from flwr.app.error import Error
26
+ from flwr.common import Message, Metadata, RecordDict, now
26
27
  from flwr.common.constant import (
27
28
  SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS,
28
29
  SUPERLINK_NODE_ID,
30
+ ErrorCode,
31
+ MessageType,
29
32
  )
30
33
  from flwr.common.grpc import create_channel, on_channel_state_change
34
+ from flwr.common.inflatable import (
35
+ InflatableObject,
36
+ get_all_nested_objects,
37
+ get_object_tree,
38
+ iterate_object_tree,
39
+ no_object_id_recompute,
40
+ )
41
+ from flwr.common.inflatable_protobuf_utils import (
42
+ make_pull_object_fn_protobuf,
43
+ make_push_object_fn_protobuf,
44
+ )
45
+ from flwr.common.inflatable_utils import (
46
+ ObjectUnavailableError,
47
+ inflate_object_from_contents,
48
+ pull_objects,
49
+ push_objects,
50
+ )
31
51
  from flwr.common.logger import log, warn_deprecated_feature
52
+ from flwr.common.message import make_message, remove_content_from_message
32
53
  from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub
33
- from flwr.common.serde import message_from_proto, message_to_proto, run_from_proto
54
+ from flwr.common.serde import message_to_proto, run_from_proto
34
55
  from flwr.common.typing import Run
35
- from flwr.proto.message_pb2 import Message as ProtoMessage # pylint: disable=E0611
56
+ from flwr.proto.appio_pb2 import ( # pylint: disable=E0611
57
+ PullAppMessagesRequest,
58
+ PullAppMessagesResponse,
59
+ PushAppMessagesRequest,
60
+ PushAppMessagesResponse,
61
+ )
62
+ from flwr.proto.message_pb2 import ( # pylint: disable=E0611
63
+ ConfirmMessageReceivedRequest,
64
+ )
36
65
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
37
66
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
38
67
  from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
39
68
  GetNodesRequest,
40
69
  GetNodesResponse,
41
- PullResMessagesRequest,
42
- PullResMessagesResponse,
43
- PushInsMessagesRequest,
44
- PushInsMessagesResponse,
45
70
  )
46
71
  from flwr.proto.serverappio_pb2_grpc import ServerAppIoStub # pylint: disable=E0611
47
72
 
@@ -163,7 +188,7 @@ class GrpcGrid(Grid):
163
188
  def _check_message(self, message: Message) -> None:
164
189
  # Check if the message is valid
165
190
  if not (
166
- message.metadata.message_id == ""
191
+ message.metadata.message_id != ""
167
192
  and message.metadata.reply_to_message_id == ""
168
193
  and message.metadata.ttl > 0
169
194
  ):
@@ -198,6 +223,39 @@ class GrpcGrid(Grid):
198
223
  )
199
224
  return [node.node_id for node in res.nodes]
200
225
 
226
+ def _try_push_messages(self, run_id: int, messages: Iterable[Message]) -> list[str]:
227
+ """Push all messages and its associated objects."""
228
+ # Prepare all Messages to be sent in a single request
229
+ proto_messages = []
230
+ object_trees = []
231
+ all_objects: dict[str, InflatableObject] = {}
232
+ for msg in messages:
233
+ proto_messages.append(message_to_proto(remove_content_from_message(msg)))
234
+ all_objects.update(get_all_nested_objects(msg))
235
+ object_trees.append(get_object_tree(msg))
236
+ del msg
237
+
238
+ # Call GrpcServerAppIoStub method
239
+ res: PushAppMessagesResponse = self._stub.PushMessages(
240
+ PushAppMessagesRequest(
241
+ messages_list=proto_messages,
242
+ run_id=run_id,
243
+ message_object_trees=object_trees,
244
+ )
245
+ )
246
+
247
+ # Push objects
248
+ push_objects(
249
+ all_objects,
250
+ push_object_fn=make_push_object_fn_protobuf(
251
+ push_object_protobuf=self._stub.PushObject,
252
+ node=self.node,
253
+ run_id=run_id,
254
+ ),
255
+ object_ids_to_push=set(res.objects_to_push),
256
+ )
257
+ return cast(list[str], res.message_ids)
258
+
201
259
  def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
202
260
  """Push messages to specified node IDs.
203
261
 
@@ -206,57 +264,106 @@ class GrpcGrid(Grid):
206
264
  """
207
265
  # Construct Messages
208
266
  run_id = cast(Run, self._run).run_id
209
- message_proto_list: list[ProtoMessage] = []
210
- for msg in messages:
211
- # Populate metadata
212
- msg.metadata.__dict__["_run_id"] = run_id
213
- msg.metadata.__dict__["_src_node_id"] = self.node.node_id
214
- # Check message
215
- self._check_message(msg)
216
- # Convert to proto
217
- msg_proto = message_to_proto(msg)
218
- # Add to list
219
- message_proto_list.append(msg_proto)
220
-
267
+ message_ids: list[str] = []
221
268
  try:
222
- # Call GrpcServerAppIoStub method
223
- res: PushInsMessagesResponse = self._stub.PushMessages(
224
- PushInsMessagesRequest(messages_list=message_proto_list, run_id=run_id)
225
- )
226
- if len([msg_id for msg_id in res.message_ids if msg_id]) != len(
227
- message_proto_list
228
- ):
229
- log(
230
- WARNING,
231
- "Not all messages could be pushed to the SuperLink. The returned "
232
- "list has `None` for those messages (the order is preserved as "
233
- "passed to `push_messages`). This could be due to a malformed "
234
- "message.",
235
- )
236
- return list(res.message_ids)
269
+ with no_object_id_recompute():
270
+ for msg in messages:
271
+ # Populate metadata
272
+ msg.metadata.__dict__["_run_id"] = run_id
273
+ msg.metadata.__dict__["_src_node_id"] = self.node.node_id
274
+ msg.metadata.__dict__["_message_id"] = msg.object_id
275
+ # Check message
276
+ self._check_message(msg)
277
+ # Try pushing messages and their objects
278
+ message_ids = self._try_push_messages(run_id, messages)
279
+
237
280
  except grpc.RpcError as e:
238
281
  if e.code() == grpc.StatusCode.RESOURCE_EXHAUSTED: # pylint: disable=E1101
239
282
  log(ERROR, ERROR_MESSAGE_PUSH_MESSAGES_RESOURCE_EXHAUSTED)
240
283
  return []
241
284
  raise
242
285
 
286
+ if None in message_ids:
287
+ log(
288
+ WARNING,
289
+ "Not all messages could be pushed to the SuperLink. The returned "
290
+ "list has `None` for those messages (the order is preserved as "
291
+ "passed to `push_messages`). This could be due to a malformed "
292
+ "message.",
293
+ )
294
+
295
+ return message_ids
296
+
243
297
  def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
244
298
  """Pull messages based on message IDs.
245
299
 
246
300
  This method is used to collect messages from the SuperLink that correspond to a
247
301
  set of given message IDs.
248
302
  """
303
+ run_id = cast(Run, self._run).run_id
249
304
  try:
250
305
  # Pull Messages
251
- res: PullResMessagesResponse = self._stub.PullMessages(
252
- PullResMessagesRequest(
306
+ res: PullAppMessagesResponse = self._stub.PullMessages(
307
+ PullAppMessagesRequest(
253
308
  message_ids=message_ids,
254
- run_id=cast(Run, self._run).run_id,
309
+ run_id=run_id,
255
310
  )
256
311
  )
257
- # Convert Message from Protobuf representation
258
- msgs = [message_from_proto(msg_proto) for msg_proto in res.messages_list]
259
- return msgs
312
+ # Pull Messages from store
313
+ inflated_msgs: list[Message] = []
314
+ for msg_proto, msg_tree in zip(res.messages_list, res.message_object_trees):
315
+ msg_id = msg_proto.metadata.message_id
316
+ try:
317
+ all_object_contents = pull_objects(
318
+ object_ids=[
319
+ tree.object_id for tree in iterate_object_tree(msg_tree)
320
+ ],
321
+ pull_object_fn=make_pull_object_fn_protobuf(
322
+ pull_object_protobuf=self._stub.PullObject,
323
+ node=self.node,
324
+ run_id=run_id,
325
+ ),
326
+ )
327
+ except ObjectUnavailableError as e:
328
+ # An ObjectUnavailableError indicates that the object is not yet
329
+ # available. If this point has been reached, it means that the
330
+ # Grid has tried to pull the object for the maximum number of times
331
+ # or for the maximum time allowed, so we return an inflated message
332
+ # with an error
333
+ inflated_msgs.append(
334
+ make_message(
335
+ metadata=Metadata(
336
+ run_id=run_id,
337
+ message_id="",
338
+ src_node_id=self.node.node_id,
339
+ dst_node_id=self.node.node_id,
340
+ message_type=MessageType.SYSTEM,
341
+ group_id="",
342
+ ttl=0,
343
+ reply_to_message_id=msg_proto.metadata.reply_to_message_id,
344
+ created_at=now().timestamp(),
345
+ ),
346
+ error=Error(
347
+ code=ErrorCode.MESSAGE_UNAVAILABLE, reason=(str(e))
348
+ ),
349
+ )
350
+ )
351
+ continue
352
+
353
+ # Confirm that the message has been received
354
+ self._stub.ConfirmMessageReceived(
355
+ ConfirmMessageReceivedRequest(
356
+ node=self.node, run_id=run_id, message_object_id=msg_id
357
+ )
358
+ )
359
+ message = cast(
360
+ Message, inflate_object_from_contents(msg_id, all_object_contents)
361
+ )
362
+ message.metadata.__dict__["_message_id"] = msg_id
363
+ inflated_msgs.append(message)
364
+
365
+ return inflated_msgs
366
+
260
367
  except grpc.RpcError as e:
261
368
  if e.code() == grpc.StatusCode.RESOURCE_EXHAUSTED: # pylint: disable=E1101
262
369
  log(ERROR, ERROR_MESSAGE_PULL_MESSAGES_RESOURCE_EXHAUSTED)
@@ -18,7 +18,7 @@
18
18
  import time
19
19
  from collections.abc import Iterable
20
20
  from typing import Optional, cast
21
- from uuid import UUID
21
+ from uuid import uuid4
22
22
 
23
23
  from flwr.common import Message, RecordDict
24
24
  from flwr.common.constant import SUPERLINK_NODE_ID
@@ -56,7 +56,7 @@ class InMemoryGrid(Grid):
56
56
  def _check_message(self, message: Message) -> None:
57
57
  # Check if the message is valid
58
58
  if not (
59
- message.metadata.message_id == ""
59
+ message.metadata.message_id != ""
60
60
  and message.metadata.reply_to_message_id == ""
61
61
  and message.metadata.ttl > 0
62
62
  and message.metadata.delivered_at == ""
@@ -111,6 +111,7 @@ class InMemoryGrid(Grid):
111
111
  # Populate metadata
112
112
  msg.metadata.__dict__["_run_id"] = cast(Run, self._run).run_id
113
113
  msg.metadata.__dict__["_src_node_id"] = self.node.node_id
114
+ msg.metadata.__dict__["_message_id"] = str(uuid4())
114
115
  # Check message
115
116
  self._check_message(msg)
116
117
  # Store in state
@@ -126,12 +127,12 @@ class InMemoryGrid(Grid):
126
127
  This method is used to collect messages from the SuperLink that correspond to a
127
128
  set of given message IDs.
128
129
  """
129
- msg_ids = {UUID(msg_id) for msg_id in message_ids}
130
+ msg_ids = set(message_ids)
130
131
  # Pull Messages
131
132
  message_res_list = self.state.get_message_res(message_ids=msg_ids)
132
133
  # Get IDs of Messages these replies are for
133
134
  message_ins_ids_to_delete = {
134
- UUID(msg_res.metadata.reply_to_message_id) for msg_res in message_res_list
135
+ msg_res.metadata.reply_to_message_id for msg_res in message_res_list
135
136
  }
136
137
  # Delete
137
138
  self.state.delete_messages(message_ins_ids=message_ins_ids_to_delete)
@@ -16,6 +16,7 @@
16
16
 
17
17
 
18
18
  import argparse
19
+ import gc
19
20
  from logging import DEBUG, ERROR, INFO
20
21
  from pathlib import Path
21
22
  from queue import Queue
@@ -38,6 +39,7 @@ from flwr.common.constant import (
38
39
  SubStatus,
39
40
  )
40
41
  from flwr.common.exit import ExitCode, flwr_exit
42
+ from flwr.common.heartbeat import HeartbeatSender, get_grpc_app_heartbeat_fn
41
43
  from flwr.common.logger import (
42
44
  log,
43
45
  mirror_output_to_queue,
@@ -54,12 +56,12 @@ from flwr.common.serde import (
54
56
  )
55
57
  from flwr.common.telemetry import EventType, event
56
58
  from flwr.common.typing import RunNotRunningException, RunStatus
57
- from flwr.proto.run_pb2 import UpdateRunStatusRequest # pylint: disable=E0611
58
- from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
59
- PullServerAppInputsRequest,
60
- PullServerAppInputsResponse,
61
- PushServerAppOutputsRequest,
59
+ from flwr.proto.appio_pb2 import ( # pylint: disable=E0611
60
+ PullAppInputsRequest,
61
+ PullAppInputsResponse,
62
+ PushAppOutputsRequest,
62
63
  )
64
+ from flwr.proto.run_pb2 import UpdateRunStatusRequest # pylint: disable=E0611
63
65
  from flwr.server.grid.grpc_grid import GrpcGrid
64
66
  from flwr.server.run_serverapp import run as run_
65
67
 
@@ -106,24 +108,28 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
106
108
  certificates: Optional[bytes] = None,
107
109
  ) -> None:
108
110
  """Run Flower ServerApp process."""
109
- grid = GrpcGrid(
110
- serverappio_service_address=serverappio_api_address,
111
- root_certificates=certificates,
112
- )
113
-
114
111
  # Resolve directory where FABs are installed
115
112
  flwr_dir_ = get_flwr_dir(flwr_dir)
116
113
  log_uploader = None
117
114
  success = True
118
115
  hash_run_id = None
119
116
  run_status = None
117
+ heartbeat_sender = None
118
+ grid = None
119
+ context = None
120
120
  while True:
121
121
 
122
122
  try:
123
+ # Initialize the GrpcGrid
124
+ grid = GrpcGrid(
125
+ serverappio_service_address=serverappio_api_address,
126
+ root_certificates=certificates,
127
+ )
128
+
123
129
  # Pull ServerAppInputs from LinkState
124
- req = PullServerAppInputsRequest()
130
+ req = PullAppInputsRequest()
125
131
  log(DEBUG, "[flwr-serverapp] Pull ServerAppInputs")
126
- res: PullServerAppInputsResponse = grid._stub.PullServerAppInputs(req)
132
+ res: PullAppInputsResponse = grid._stub.PullAppInputs(req)
127
133
  if not res.HasField("run"):
128
134
  sleep(3)
129
135
  run_status = None
@@ -182,6 +188,16 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
182
188
  event_details={"run-id-hash": hash_run_id},
183
189
  )
184
190
 
191
+ # Set up heartbeat sender
192
+ heartbeat_fn = get_grpc_app_heartbeat_fn(
193
+ grid._stub,
194
+ run.run_id,
195
+ failure_message="Heartbeat failed unexpectedly. The SuperLink could "
196
+ "not find the provided run ID, or the run status is invalid.",
197
+ )
198
+ heartbeat_sender = HeartbeatSender(heartbeat_fn)
199
+ heartbeat_sender.start()
200
+
185
201
  # Load and run the ServerApp with the Grid
186
202
  updated_context = run_(
187
203
  grid=grid,
@@ -193,10 +209,8 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
193
209
  # Send resulting context
194
210
  context_proto = context_to_proto(updated_context)
195
211
  log(DEBUG, "[flwr-serverapp] Will push ServerAppOutputs")
196
- out_req = PushServerAppOutputsRequest(
197
- run_id=run.run_id, context=context_proto
198
- )
199
- _ = grid._stub.PushServerAppOutputs(out_req)
212
+ out_req = PushAppOutputsRequest(run_id=run.run_id, context=context_proto)
213
+ _ = grid._stub.PushAppOutputs(out_req)
200
214
 
201
215
  run_status = RunStatus(Status.FINISHED, SubStatus.COMPLETED, "")
202
216
  except RunNotRunningException:
@@ -213,19 +227,33 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
213
227
  success = False
214
228
 
215
229
  finally:
230
+ # Stop heartbeat sender
231
+ if heartbeat_sender:
232
+ heartbeat_sender.stop()
233
+ heartbeat_sender = None
234
+
216
235
  # Stop log uploader for this run and upload final logs
217
236
  if log_uploader:
218
237
  stop_log_uploader(log_queue, log_uploader)
219
238
  log_uploader = None
220
239
 
221
240
  # Update run status
222
- if run_status:
241
+ if run_status and grid:
223
242
  run_status_proto = run_status_to_proto(run_status)
224
243
  grid._stub.UpdateRunStatus(
225
244
  UpdateRunStatusRequest(
226
245
  run_id=run.run_id, run_status=run_status_proto
227
246
  )
228
247
  )
248
+
249
+ # Close the Grpc connection
250
+ if grid:
251
+ grid.close()
252
+
253
+ # Clean up the Context
254
+ context = None
255
+ gc.collect()
256
+
229
257
  event(
230
258
  EventType.FLWR_SERVERAPP_RUN_LEAVE,
231
259
  event_details={"run-id-hash": hash_run_id, "success": success},
@@ -35,11 +35,16 @@ from flwr.proto.fab_pb2 import GetFabRequest # pylint: disable=E0611
35
35
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
36
36
  CreateNodeRequest,
37
37
  DeleteNodeRequest,
38
- PingRequest,
39
38
  PullMessagesRequest,
40
39
  PushMessagesRequest,
41
40
  )
42
41
  from flwr.proto.grpcadapter_pb2 import MessageContainer # pylint: disable=E0611
42
+ from flwr.proto.heartbeat_pb2 import SendNodeHeartbeatRequest # pylint: disable=E0611
43
+ from flwr.proto.message_pb2 import ( # pylint: disable=E0611
44
+ ConfirmMessageReceivedRequest,
45
+ PullObjectRequest,
46
+ PushObjectRequest,
47
+ )
43
48
  from flwr.proto.run_pb2 import GetRunRequest # pylint: disable=E0611
44
49
 
45
50
  from ..grpc_rere.fleet_servicer import FleetServicer
@@ -81,8 +86,10 @@ class GrpcAdapterServicer(grpcadapter_pb2_grpc.GrpcAdapterServicer, FleetService
81
86
  return _handle(request, context, CreateNodeRequest, self.CreateNode)
82
87
  if request.grpc_message_name == DeleteNodeRequest.__qualname__:
83
88
  return _handle(request, context, DeleteNodeRequest, self.DeleteNode)
84
- if request.grpc_message_name == PingRequest.__qualname__:
85
- return _handle(request, context, PingRequest, self.Ping)
89
+ if request.grpc_message_name == SendNodeHeartbeatRequest.__qualname__:
90
+ return _handle(
91
+ request, context, SendNodeHeartbeatRequest, self.SendNodeHeartbeat
92
+ )
86
93
  if request.grpc_message_name == GetRunRequest.__qualname__:
87
94
  return _handle(request, context, GetRunRequest, self.GetRun)
88
95
  if request.grpc_message_name == GetFabRequest.__qualname__:
@@ -91,4 +98,15 @@ class GrpcAdapterServicer(grpcadapter_pb2_grpc.GrpcAdapterServicer, FleetService
91
98
  return _handle(request, context, PullMessagesRequest, self.PullMessages)
92
99
  if request.grpc_message_name == PushMessagesRequest.__qualname__:
93
100
  return _handle(request, context, PushMessagesRequest, self.PushMessages)
101
+ if request.grpc_message_name == PushObjectRequest.__qualname__:
102
+ return _handle(request, context, PushObjectRequest, self.PushObject)
103
+ if request.grpc_message_name == PullObjectRequest.__qualname__:
104
+ return _handle(request, context, PullObjectRequest, self.PullObject)
105
+ if request.grpc_message_name == ConfirmMessageReceivedRequest.__qualname__:
106
+ return _handle(
107
+ request,
108
+ context,
109
+ ConfirmMessageReceivedRequest,
110
+ self.ConfirmMessageReceived,
111
+ )
94
112
  raise ValueError(f"Invalid grpc_message_name: {request.grpc_message_name}")
@@ -20,6 +20,7 @@ from logging import DEBUG, INFO
20
20
  import grpc
21
21
  from google.protobuf.json_format import MessageToDict
22
22
 
23
+ from flwr.common.inflatable import UnexpectedObjectContentError
23
24
  from flwr.common.logger import log
24
25
  from flwr.common.typing import InvalidRunStatusException
25
26
  from flwr.proto import fleet_pb2_grpc # pylint: disable=E0611
@@ -29,34 +30,53 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
29
30
  CreateNodeResponse,
30
31
  DeleteNodeRequest,
31
32
  DeleteNodeResponse,
32
- PingRequest,
33
- PingResponse,
34
33
  PullMessagesRequest,
35
34
  PullMessagesResponse,
36
35
  PushMessagesRequest,
37
36
  PushMessagesResponse,
38
37
  )
38
+ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
39
+ SendNodeHeartbeatRequest,
40
+ SendNodeHeartbeatResponse,
41
+ )
42
+ from flwr.proto.message_pb2 import ( # pylint: disable=E0611
43
+ ConfirmMessageReceivedRequest,
44
+ ConfirmMessageReceivedResponse,
45
+ PullObjectRequest,
46
+ PullObjectResponse,
47
+ PushObjectRequest,
48
+ PushObjectResponse,
49
+ )
39
50
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
40
- from flwr.server.superlink.ffs.ffs_factory import FfsFactory
41
51
  from flwr.server.superlink.fleet.message_handler import message_handler
42
52
  from flwr.server.superlink.linkstate import LinkStateFactory
43
53
  from flwr.server.superlink.utils import abort_grpc_context
54
+ from flwr.supercore.ffs import FfsFactory
55
+ from flwr.supercore.object_store import ObjectStoreFactory
44
56
 
45
57
 
46
58
  class FleetServicer(fleet_pb2_grpc.FleetServicer):
47
59
  """Fleet API servicer."""
48
60
 
49
61
  def __init__(
50
- self, state_factory: LinkStateFactory, ffs_factory: FfsFactory
62
+ self,
63
+ state_factory: LinkStateFactory,
64
+ ffs_factory: FfsFactory,
65
+ objectstore_factory: ObjectStoreFactory,
51
66
  ) -> None:
52
67
  self.state_factory = state_factory
53
68
  self.ffs_factory = ffs_factory
69
+ self.objectstore_factory = objectstore_factory
54
70
 
55
71
  def CreateNode(
56
72
  self, request: CreateNodeRequest, context: grpc.ServicerContext
57
73
  ) -> CreateNodeResponse:
58
74
  """."""
59
- log(INFO, "[Fleet.CreateNode] Request ping_interval=%s", request.ping_interval)
75
+ log(
76
+ INFO,
77
+ "[Fleet.CreateNode] Request heartbeat_interval=%s",
78
+ request.heartbeat_interval,
79
+ )
60
80
  log(DEBUG, "[Fleet.CreateNode] Request: %s", MessageToDict(request))
61
81
  response = message_handler.create_node(
62
82
  request=request,
@@ -77,10 +97,12 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
77
97
  state=self.state_factory.state(),
78
98
  )
79
99
 
80
- def Ping(self, request: PingRequest, context: grpc.ServicerContext) -> PingResponse:
100
+ def SendNodeHeartbeat(
101
+ self, request: SendNodeHeartbeatRequest, context: grpc.ServicerContext
102
+ ) -> SendNodeHeartbeatResponse:
81
103
  """."""
82
- log(DEBUG, "[Fleet.Ping] Request: %s", MessageToDict(request))
83
- return message_handler.ping(
104
+ log(DEBUG, "[Fleet.SendNodeHeartbeat] Request: %s", MessageToDict(request))
105
+ return message_handler.send_node_heartbeat(
84
106
  request=request,
85
107
  state=self.state_factory.state(),
86
108
  )
@@ -94,6 +116,7 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
94
116
  return message_handler.pull_messages(
95
117
  request=request,
96
118
  state=self.state_factory.state(),
119
+ store=self.objectstore_factory.store(),
97
120
  )
98
121
 
99
122
  def PushMessages(
@@ -113,6 +136,7 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
113
136
  res = message_handler.push_messages(
114
137
  request=request,
115
138
  state=self.state_factory.state(),
139
+ store=self.objectstore_factory.store(),
116
140
  )
117
141
  except InvalidRunStatusException as e:
118
142
  abort_grpc_context(e.message, context)
@@ -129,6 +153,7 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
129
153
  res = message_handler.get_run(
130
154
  request=request,
131
155
  state=self.state_factory.state(),
156
+ store=self.objectstore_factory.store(),
132
157
  )
133
158
  except InvalidRunStatusException as e:
134
159
  abort_grpc_context(e.message, context)
@@ -145,6 +170,75 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
145
170
  request=request,
146
171
  ffs=self.ffs_factory.ffs(),
147
172
  state=self.state_factory.state(),
173
+ store=self.objectstore_factory.store(),
174
+ )
175
+ except InvalidRunStatusException as e:
176
+ abort_grpc_context(e.message, context)
177
+
178
+ return res
179
+
180
+ def PushObject(
181
+ self, request: PushObjectRequest, context: grpc.ServicerContext
182
+ ) -> PushObjectResponse:
183
+ """Push an object to the ObjectStore."""
184
+ log(
185
+ DEBUG,
186
+ "[ServerAppIoServicer.PushObject] Push Object with object_id=%s",
187
+ request.object_id,
188
+ )
189
+
190
+ try:
191
+ # Insert in Store
192
+ res = message_handler.push_object(
193
+ request=request,
194
+ state=self.state_factory.state(),
195
+ store=self.objectstore_factory.store(),
196
+ )
197
+ except InvalidRunStatusException as e:
198
+ abort_grpc_context(e.message, context)
199
+ except UnexpectedObjectContentError as e:
200
+ # Object content is not valid
201
+ context.abort(grpc.StatusCode.FAILED_PRECONDITION, str(e))
202
+
203
+ return res
204
+
205
+ def PullObject(
206
+ self, request: PullObjectRequest, context: grpc.ServicerContext
207
+ ) -> PullObjectResponse:
208
+ """Pull an object from the ObjectStore."""
209
+ log(
210
+ DEBUG,
211
+ "[ServerAppIoServicer.PullObject] Pull Object with object_id=%s",
212
+ request.object_id,
213
+ )
214
+
215
+ try:
216
+ # Fetch from store
217
+ res = message_handler.pull_object(
218
+ request=request,
219
+ state=self.state_factory.state(),
220
+ store=self.objectstore_factory.store(),
221
+ )
222
+ except InvalidRunStatusException as e:
223
+ abort_grpc_context(e.message, context)
224
+
225
+ return res
226
+
227
+ def ConfirmMessageReceived(
228
+ self, request: ConfirmMessageReceivedRequest, context: grpc.ServicerContext
229
+ ) -> ConfirmMessageReceivedResponse:
230
+ """Confirm message received."""
231
+ log(
232
+ DEBUG,
233
+ "[Fleet.ConfirmMessageReceived] Message with ID '%s' has been received",
234
+ request.message_object_id,
235
+ )
236
+
237
+ try:
238
+ res = message_handler.confirm_message_received(
239
+ request=request,
240
+ state=self.state_factory.state(),
241
+ store=self.objectstore_factory.store(),
148
242
  )
149
243
  except InvalidRunStatusException as e:
150
244
  abort_grpc_context(e.message, context)
@@ -81,12 +81,9 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
81
81
  metadata sent by the node. Continue RPC call if node is authenticated, else,
82
82
  terminate RPC call by setting context to abort.
83
83
  """
84
- # Filter out non-Fleet service calls
84
+ # Only apply to Fleet service
85
85
  if not handler_call_details.method.startswith("/flwr.proto.Fleet/"):
86
- return _unary_unary_rpc_terminator(
87
- "This request should be sent to a different service.",
88
- grpc.StatusCode.FAILED_PRECONDITION,
89
- )
86
+ return continuation(handler_call_details)
90
87
 
91
88
  state = self.state_factory.state()
92
89
  metadata_dict = dict(handler_call_details.invocation_metadata)