flwr 1.18.0__py3-none-any.whl → 1.19.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (143)
  1. flwr/app/__init__.py +15 -0
  2. flwr/app/error.py +68 -0
  3. flwr/app/metadata.py +223 -0
  4. flwr/cli/build.py +82 -57
  5. flwr/cli/log.py +3 -3
  6. flwr/cli/login/login.py +3 -7
  7. flwr/cli/ls.py +15 -36
  8. flwr/cli/new/templates/app/code/client.baseline.py.tpl +1 -1
  9. flwr/cli/new/templates/app/code/model.baseline.py.tpl +1 -1
  10. flwr/cli/new/templates/app/code/server.baseline.py.tpl +2 -3
  11. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +14 -17
  12. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
  13. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
  14. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  15. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
  16. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  17. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
  18. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  19. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  20. flwr/cli/run/run.py +10 -18
  21. flwr/cli/stop.py +2 -2
  22. flwr/cli/utils.py +31 -5
  23. flwr/client/__init__.py +2 -2
  24. flwr/client/client_app.py +1 -1
  25. flwr/client/clientapp/__init__.py +0 -7
  26. flwr/client/grpc_adapter_client/connection.py +4 -4
  27. flwr/client/grpc_rere_client/connection.py +130 -60
  28. flwr/client/grpc_rere_client/grpc_adapter.py +34 -6
  29. flwr/client/message_handler/message_handler.py +1 -1
  30. flwr/client/mod/comms_mods.py +36 -17
  31. flwr/client/rest_client/connection.py +173 -67
  32. flwr/clientapp/__init__.py +15 -0
  33. flwr/common/__init__.py +2 -2
  34. flwr/common/auth_plugin/__init__.py +2 -0
  35. flwr/common/auth_plugin/auth_plugin.py +29 -3
  36. flwr/common/constant.py +36 -7
  37. flwr/common/event_log_plugin/event_log_plugin.py +3 -3
  38. flwr/common/exit_handlers.py +30 -0
  39. flwr/common/heartbeat.py +165 -0
  40. flwr/common/inflatable.py +290 -0
  41. flwr/common/inflatable_grpc_utils.py +99 -0
  42. flwr/common/inflatable_rest_utils.py +99 -0
  43. flwr/common/inflatable_utils.py +341 -0
  44. flwr/common/message.py +110 -242
  45. flwr/common/record/__init__.py +2 -1
  46. flwr/common/record/array.py +323 -0
  47. flwr/common/record/arrayrecord.py +103 -225
  48. flwr/common/record/configrecord.py +59 -4
  49. flwr/common/record/conversion_utils.py +1 -1
  50. flwr/common/record/metricrecord.py +55 -4
  51. flwr/common/record/recorddict.py +69 -1
  52. flwr/common/recorddict_compat.py +2 -2
  53. flwr/common/retry_invoker.py +5 -1
  54. flwr/common/serde.py +59 -183
  55. flwr/common/serde_utils.py +175 -0
  56. flwr/common/typing.py +5 -3
  57. flwr/compat/__init__.py +15 -0
  58. flwr/compat/client/__init__.py +15 -0
  59. flwr/{client → compat/client}/app.py +19 -159
  60. flwr/compat/common/__init__.py +15 -0
  61. flwr/compat/server/__init__.py +15 -0
  62. flwr/compat/server/app.py +174 -0
  63. flwr/compat/simulation/__init__.py +15 -0
  64. flwr/proto/fleet_pb2.py +32 -27
  65. flwr/proto/fleet_pb2.pyi +49 -35
  66. flwr/proto/fleet_pb2_grpc.py +117 -13
  67. flwr/proto/fleet_pb2_grpc.pyi +47 -6
  68. flwr/proto/heartbeat_pb2.py +33 -0
  69. flwr/proto/heartbeat_pb2.pyi +66 -0
  70. flwr/proto/heartbeat_pb2_grpc.py +4 -0
  71. flwr/proto/heartbeat_pb2_grpc.pyi +4 -0
  72. flwr/proto/message_pb2.py +28 -11
  73. flwr/proto/message_pb2.pyi +125 -0
  74. flwr/proto/recorddict_pb2.py +16 -28
  75. flwr/proto/recorddict_pb2.pyi +46 -64
  76. flwr/proto/run_pb2.py +24 -32
  77. flwr/proto/run_pb2.pyi +4 -52
  78. flwr/proto/serverappio_pb2.py +32 -23
  79. flwr/proto/serverappio_pb2.pyi +45 -3
  80. flwr/proto/serverappio_pb2_grpc.py +138 -34
  81. flwr/proto/serverappio_pb2_grpc.pyi +54 -13
  82. flwr/proto/simulationio_pb2.py +12 -11
  83. flwr/proto/simulationio_pb2_grpc.py +35 -0
  84. flwr/proto/simulationio_pb2_grpc.pyi +14 -0
  85. flwr/server/__init__.py +1 -1
  86. flwr/server/app.py +68 -186
  87. flwr/server/compat/app_utils.py +50 -28
  88. flwr/server/fleet_event_log_interceptor.py +2 -2
  89. flwr/server/grid/grpc_grid.py +104 -34
  90. flwr/server/grid/inmemory_grid.py +5 -4
  91. flwr/server/serverapp/app.py +18 -0
  92. flwr/server/superlink/ffs/__init__.py +2 -0
  93. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +13 -3
  94. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +101 -7
  95. flwr/server/superlink/fleet/message_handler/message_handler.py +135 -18
  96. flwr/server/superlink/fleet/rest_rere/rest_api.py +72 -11
  97. flwr/server/superlink/fleet/vce/vce_api.py +6 -3
  98. flwr/server/superlink/linkstate/in_memory_linkstate.py +138 -43
  99. flwr/server/superlink/linkstate/linkstate.py +53 -20
  100. flwr/server/superlink/linkstate/sqlite_linkstate.py +149 -55
  101. flwr/server/superlink/linkstate/utils.py +33 -29
  102. flwr/server/superlink/serverappio/serverappio_grpc.py +3 -0
  103. flwr/server/superlink/serverappio/serverappio_servicer.py +211 -57
  104. flwr/server/superlink/simulation/simulationio_servicer.py +25 -1
  105. flwr/server/superlink/utils.py +44 -2
  106. flwr/server/utils/validator.py +2 -2
  107. flwr/serverapp/__init__.py +15 -0
  108. flwr/simulation/app.py +17 -0
  109. flwr/supercore/__init__.py +15 -0
  110. flwr/supercore/object_store/__init__.py +24 -0
  111. flwr/supercore/object_store/in_memory_object_store.py +229 -0
  112. flwr/supercore/object_store/object_store.py +192 -0
  113. flwr/supercore/object_store/object_store_factory.py +44 -0
  114. flwr/superexec/deployment.py +6 -2
  115. flwr/superexec/exec_event_log_interceptor.py +4 -4
  116. flwr/superexec/exec_grpc.py +7 -3
  117. flwr/superexec/exec_servicer.py +125 -23
  118. flwr/superexec/exec_user_auth_interceptor.py +37 -8
  119. flwr/superexec/executor.py +4 -0
  120. flwr/superexec/simulation.py +7 -1
  121. flwr/superlink/__init__.py +15 -0
  122. flwr/{client/supernode → supernode}/__init__.py +0 -7
  123. flwr/{client/nodestate/nodestate.py → supernode/cli/__init__.py} +7 -14
  124. flwr/{client/supernode/app.py → supernode/cli/flower_supernode.py} +3 -12
  125. flwr/supernode/cli/flwr_clientapp.py +81 -0
  126. flwr/supernode/nodestate/in_memory_nodestate.py +190 -0
  127. flwr/supernode/nodestate/nodestate.py +212 -0
  128. flwr/supernode/runtime/__init__.py +15 -0
  129. flwr/{client/clientapp/app.py → supernode/runtime/run_clientapp.py} +25 -56
  130. flwr/supernode/servicer/__init__.py +15 -0
  131. flwr/supernode/servicer/clientappio/__init__.py +24 -0
  132. flwr/supernode/start_client_internal.py +491 -0
  133. {flwr-1.18.0.dist-info → flwr-1.19.0.dist-info}/METADATA +5 -4
  134. {flwr-1.18.0.dist-info → flwr-1.19.0.dist-info}/RECORD +141 -108
  135. {flwr-1.18.0.dist-info → flwr-1.19.0.dist-info}/WHEEL +1 -1
  136. {flwr-1.18.0.dist-info → flwr-1.19.0.dist-info}/entry_points.txt +2 -2
  137. flwr/client/heartbeat.py +0 -74
  138. flwr/client/nodestate/in_memory_nodestate.py +0 -38
  139. /flwr/{client → compat/client}/grpc_client/__init__.py +0 -0
  140. /flwr/{client → compat/client}/grpc_client/connection.py +0 -0
  141. /flwr/{client → supernode}/nodestate/__init__.py +0 -0
  142. /flwr/{client → supernode}/nodestate/nodestate_factory.py +0 -0
  143. /flwr/{client/clientapp → supernode/servicer/clientappio}/clientappio_servicer.py +0 -0
@@ -15,47 +15,61 @@
15
15
  """Contextmanager for a gRPC request-response channel to the Flower server."""
16
16
 
17
17
 
18
- import random
19
- import threading
20
18
  from collections.abc import Iterator, Sequence
21
19
  from contextlib import contextmanager
22
20
  from copy import copy
23
- from logging import ERROR
21
+ from logging import DEBUG, ERROR
24
22
  from pathlib import Path
25
23
  from typing import Callable, Optional, Union, cast
26
24
 
27
25
  import grpc
28
26
  from cryptography.hazmat.primitives.asymmetric import ec
29
27
 
30
- from flwr.client.heartbeat import start_ping_loop
28
+ from flwr.app.metadata import Metadata
31
29
  from flwr.client.message_handler.message_handler import validate_out_message
32
30
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH
33
- from flwr.common.constant import (
34
- PING_BASE_MULTIPLIER,
35
- PING_CALL_TIMEOUT,
36
- PING_DEFAULT_INTERVAL,
37
- PING_RANDOM_RANGE,
38
- )
31
+ from flwr.common.constant import HEARTBEAT_CALL_TIMEOUT, HEARTBEAT_DEFAULT_INTERVAL
39
32
  from flwr.common.grpc import create_channel, on_channel_state_change
33
+ from flwr.common.heartbeat import HeartbeatSender
34
+ from flwr.common.inflatable import (
35
+ get_all_nested_objects,
36
+ get_object_tree,
37
+ no_object_id_recompute,
38
+ )
39
+ from flwr.common.inflatable_grpc_utils import (
40
+ make_pull_object_fn_grpc,
41
+ make_push_object_fn_grpc,
42
+ )
43
+ from flwr.common.inflatable_utils import (
44
+ inflate_object_from_contents,
45
+ pull_objects,
46
+ push_objects,
47
+ )
40
48
  from flwr.common.logger import log
41
- from flwr.common.message import Message, Metadata
42
- from flwr.common.retry_invoker import RetryInvoker
49
+ from flwr.common.message import Message, remove_content_from_message
50
+ from flwr.common.retry_invoker import RetryInvoker, _wrap_stub
43
51
  from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
44
52
  generate_key_pairs,
45
53
  )
46
- from flwr.common.serde import message_from_proto, message_to_proto, run_from_proto
54
+ from flwr.common.serde import message_to_proto, run_from_proto
47
55
  from flwr.common.typing import Fab, Run, RunNotRunningException
48
56
  from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
49
57
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
50
58
  CreateNodeRequest,
51
59
  DeleteNodeRequest,
52
- PingRequest,
53
- PingResponse,
54
60
  PullMessagesRequest,
55
61
  PullMessagesResponse,
56
62
  PushMessagesRequest,
63
+ PushMessagesResponse,
57
64
  )
58
65
  from flwr.proto.fleet_pb2_grpc import FleetStub # pylint: disable=E0611
66
+ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
67
+ SendNodeHeartbeatRequest,
68
+ SendNodeHeartbeatResponse,
69
+ )
70
+ from flwr.proto.message_pb2 import ( # pylint: disable=E0611
71
+ ConfirmMessageReceivedRequest,
72
+ )
59
73
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
60
74
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
61
75
 
@@ -78,10 +92,10 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
78
92
  tuple[
79
93
  Callable[[], Optional[Message]],
80
94
  Callable[[Message], None],
81
- Optional[Callable[[], Optional[int]]],
82
- Optional[Callable[[], None]],
83
- Optional[Callable[[int], Run]],
84
- Optional[Callable[[str, int], Fab]],
95
+ Callable[[], Optional[int]],
96
+ Callable[[], None],
97
+ Callable[[int], Run],
98
+ Callable[[str, int], Fab],
85
99
  ]
86
100
  ]:
87
101
  """Primitives for request/response-based interaction with a server.
@@ -151,8 +165,6 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
151
165
  stub = adapter_cls(channel)
152
166
  metadata: Optional[Metadata] = None
153
167
  node: Optional[Node] = None
154
- ping_thread: Optional[threading.Thread] = None
155
- ping_stop_event = threading.Event()
156
168
 
157
169
  def _should_giveup_fn(e: Exception) -> bool:
158
170
  if e.code() == grpc.StatusCode.PERMISSION_DENIED: # type: ignore
@@ -165,46 +177,58 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
165
177
  # If the status code is PERMISSION_DENIED, additionally raise RunNotRunningException
166
178
  retry_invoker.should_giveup = _should_giveup_fn
167
179
 
180
+ # Wrap stub
181
+ _wrap_stub(stub, retry_invoker)
168
182
  ###########################################################################
169
- # ping/create_node/delete_node/receive/send/get_run functions
183
+ # send_node_heartbeat/create_node/delete_node/receive/send/get_run functions
170
184
  ###########################################################################
171
185
 
172
- def ping() -> None:
186
+ def send_node_heartbeat() -> bool:
173
187
  # Get Node
174
188
  if node is None:
175
189
  log(ERROR, "Node instance missing")
176
- return
190
+ return False
177
191
 
178
- # Construct the ping request
179
- req = PingRequest(node=node, ping_interval=PING_DEFAULT_INTERVAL)
192
+ # Construct the heartbeat request
193
+ req = SendNodeHeartbeatRequest(
194
+ node=node, heartbeat_interval=HEARTBEAT_DEFAULT_INTERVAL
195
+ )
180
196
 
181
197
  # Call FleetAPI
182
- res: PingResponse = stub.Ping(req, timeout=PING_CALL_TIMEOUT)
198
+ try:
199
+ res: SendNodeHeartbeatResponse = stub.SendNodeHeartbeat(
200
+ req, timeout=HEARTBEAT_CALL_TIMEOUT
201
+ )
202
+ except grpc.RpcError as e:
203
+ status_code = e.code()
204
+ if status_code == grpc.StatusCode.UNAVAILABLE:
205
+ return False
206
+ if status_code == grpc.StatusCode.DEADLINE_EXCEEDED:
207
+ return False
208
+ raise
183
209
 
184
210
  # Check if success
185
211
  if not res.success:
186
- raise RuntimeError("Ping failed unexpectedly.")
212
+ raise RuntimeError(
213
+ "Heartbeat failed unexpectedly. The SuperLink does not "
214
+ "recognize this SuperNode."
215
+ )
216
+ return True
187
217
 
188
- # Wait
189
- rd = random.uniform(*PING_RANDOM_RANGE)
190
- next_interval: float = PING_DEFAULT_INTERVAL - PING_CALL_TIMEOUT
191
- next_interval *= PING_BASE_MULTIPLIER + rd
192
- if not ping_stop_event.is_set():
193
- ping_stop_event.wait(next_interval)
218
+ heartbeat_sender = HeartbeatSender(send_node_heartbeat)
194
219
 
195
220
  def create_node() -> Optional[int]:
196
221
  """Set create_node."""
197
222
  # Call FleetAPI
198
- create_node_request = CreateNodeRequest(ping_interval=PING_DEFAULT_INTERVAL)
199
- create_node_response = retry_invoker.invoke(
200
- stub.CreateNode,
201
- request=create_node_request,
223
+ create_node_request = CreateNodeRequest(
224
+ heartbeat_interval=HEARTBEAT_DEFAULT_INTERVAL
202
225
  )
226
+ create_node_response = stub.CreateNode(request=create_node_request)
203
227
 
204
- # Remember the node and the ping-loop thread
205
- nonlocal node, ping_thread
228
+ # Remember the node and start the heartbeat sender
229
+ nonlocal node
206
230
  node = cast(Node, create_node_response.node)
207
- ping_thread = start_ping_loop(ping, ping_stop_event)
231
+ heartbeat_sender.start()
208
232
  return node.node_id
209
233
 
210
234
  def delete_node() -> None:
@@ -215,12 +239,12 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
215
239
  log(ERROR, "Node instance missing")
216
240
  return
217
241
 
218
- # Stop the ping-loop thread
219
- ping_stop_event.set()
242
+ # Stop the heartbeat sender
243
+ heartbeat_sender.stop()
220
244
 
221
245
  # Call FleetAPI
222
246
  delete_node_request = DeleteNodeRequest(node=node)
223
- retry_invoker.invoke(stub.DeleteNode, request=delete_node_request)
247
+ stub.DeleteNode(request=delete_node_request)
224
248
 
225
249
  # Cleanup
226
250
  node = None
@@ -234,9 +258,7 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
234
258
 
235
259
  # Request instructions (message) from server
236
260
  request = PullMessagesRequest(node=node)
237
- response: PullMessagesResponse = retry_invoker.invoke(
238
- stub.PullMessages, request=request
239
- )
261
+ response: PullMessagesResponse = stub.PullMessages(request=request)
240
262
 
241
263
  # Get the current Messages
242
264
  message_proto = (
@@ -250,7 +272,33 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
250
272
  message_proto = None
251
273
 
252
274
  # Construct the Message
253
- in_message = message_from_proto(message_proto) if message_proto else None
275
+ in_message: Optional[Message] = None
276
+
277
+ if message_proto:
278
+ msg_id = message_proto.metadata.message_id
279
+ run_id = message_proto.metadata.run_id
280
+ all_object_contents = pull_objects(
281
+ list(response.objects_to_pull[msg_id].object_ids) + [msg_id],
282
+ pull_object_fn=make_pull_object_fn_grpc(
283
+ pull_object_grpc=stub.PullObject,
284
+ node=node,
285
+ run_id=run_id,
286
+ ),
287
+ )
288
+
289
+ # Confirm that the message has been received
290
+ stub.ConfirmMessageReceived(
291
+ ConfirmMessageReceivedRequest(
292
+ node=node, run_id=run_id, message_object_id=msg_id
293
+ )
294
+ )
295
+
296
+ in_message = cast(
297
+ Message, inflate_object_from_contents(msg_id, all_object_contents)
298
+ )
299
+ # The deflated message doesn't contain the message_id (its own object_id)
300
+ # Inject
301
+ in_message.metadata.__dict__["_message_id"] = msg_id
254
302
 
255
303
  # Remember `metadata` of the in message
256
304
  nonlocal metadata
@@ -272,15 +320,43 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
272
320
  log(ERROR, "No current message")
273
321
  return
274
322
 
323
+ # Set message_id
324
+ message.metadata.__dict__["_message_id"] = message.object_id
275
325
  # Validate out message
276
326
  if not validate_out_message(message, metadata):
277
327
  log(ERROR, "Invalid out message")
278
328
  return
279
329
 
280
- # Serialize Message
281
- message_proto = message_to_proto(message=message)
282
- request = PushMessagesRequest(node=node, messages_list=[message_proto])
283
- _ = retry_invoker.invoke(stub.PushMessages, request)
330
+ with no_object_id_recompute():
331
+ # Get all nested objects
332
+ all_objects = get_all_nested_objects(message)
333
+ object_tree = get_object_tree(message)
334
+
335
+ # Serialize Message
336
+ message_proto = message_to_proto(
337
+ message=remove_content_from_message(message)
338
+ )
339
+ request = PushMessagesRequest(
340
+ node=node,
341
+ messages_list=[message_proto],
342
+ message_object_trees=[object_tree],
343
+ )
344
+ response: PushMessagesResponse = stub.PushMessages(request=request)
345
+
346
+ if response.objects_to_push:
347
+ objs_to_push = set(
348
+ response.objects_to_push[message.object_id].object_ids
349
+ )
350
+ push_objects(
351
+ all_objects,
352
+ push_object_fn=make_push_object_fn_grpc(
353
+ push_object_grpc=stub.PushObject,
354
+ node=node,
355
+ run_id=message.metadata.run_id,
356
+ ),
357
+ object_ids_to_push=objs_to_push,
358
+ )
359
+ log(DEBUG, "Pushed %s objects to servicer.", len(objs_to_push))
284
360
 
285
361
  # Cleanup
286
362
  metadata = None
@@ -288,10 +364,7 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
288
364
  def get_run(run_id: int) -> Run:
289
365
  # Call FleetAPI
290
366
  get_run_request = GetRunRequest(node=node, run_id=run_id)
291
- get_run_response: GetRunResponse = retry_invoker.invoke(
292
- stub.GetRun,
293
- request=get_run_request,
294
- )
367
+ get_run_response: GetRunResponse = stub.GetRun(request=get_run_request)
295
368
 
296
369
  # Return fab_id and fab_version
297
370
  return run_from_proto(get_run_response.run)
@@ -299,10 +372,7 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
299
372
  def get_fab(fab_hash: str, run_id: int) -> Fab:
300
373
  # Call FleetAPI
301
374
  get_fab_request = GetFabRequest(node=node, hash_str=fab_hash, run_id=run_id)
302
- get_fab_response: GetFabResponse = retry_invoker.invoke(
303
- stub.GetFab,
304
- request=get_fab_request,
305
- )
375
+ get_fab_response: GetFabResponse = stub.GetFab(request=get_fab_request)
306
376
 
307
377
  return Fab(get_fab_response.fab.hash_str, get_fab_response.fab.content)
308
378
 
@@ -38,8 +38,6 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
38
38
  CreateNodeResponse,
39
39
  DeleteNodeRequest,
40
40
  DeleteNodeResponse,
41
- PingRequest,
42
- PingResponse,
43
41
  PullMessagesRequest,
44
42
  PullMessagesResponse,
45
43
  PushMessagesRequest,
@@ -47,6 +45,18 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
47
45
  )
48
46
  from flwr.proto.grpcadapter_pb2 import MessageContainer # pylint: disable=E0611
49
47
  from flwr.proto.grpcadapter_pb2_grpc import GrpcAdapterStub
48
+ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
49
+ SendNodeHeartbeatRequest,
50
+ SendNodeHeartbeatResponse,
51
+ )
52
+ from flwr.proto.message_pb2 import ( # pylint: disable=E0611
53
+ ConfirmMessageReceivedRequest,
54
+ ConfirmMessageReceivedResponse,
55
+ PullObjectRequest,
56
+ PullObjectResponse,
57
+ PushObjectRequest,
58
+ PushObjectResponse,
59
+ )
50
60
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
51
61
 
52
62
  T = TypeVar("T", bound=GrpcMessage)
@@ -120,11 +130,11 @@ class GrpcAdapter:
120
130
  """."""
121
131
  return self._send_and_receive(request, DeleteNodeResponse, **kwargs)
122
132
 
123
- def Ping( # pylint: disable=C0103
124
- self, request: PingRequest, **kwargs: Any
125
- ) -> PingResponse:
133
+ def SendNodeHeartbeat( # pylint: disable=C0103
134
+ self, request: SendNodeHeartbeatRequest, **kwargs: Any
135
+ ) -> SendNodeHeartbeatResponse:
126
136
  """."""
127
- return self._send_and_receive(request, PingResponse, **kwargs)
137
+ return self._send_and_receive(request, SendNodeHeartbeatResponse, **kwargs)
128
138
 
129
139
  def PullMessages( # pylint: disable=C0103
130
140
  self, request: PullMessagesRequest, **kwargs: Any
@@ -149,3 +159,21 @@ class GrpcAdapter:
149
159
  ) -> GetFabResponse:
150
160
  """."""
151
161
  return self._send_and_receive(request, GetFabResponse, **kwargs)
162
+
163
+ def PushObject( # pylint: disable=C0103
164
+ self, request: PushObjectRequest, **kwargs: Any
165
+ ) -> PushObjectResponse:
166
+ """."""
167
+ return self._send_and_receive(request, PushObjectResponse, **kwargs)
168
+
169
+ def PullObject( # pylint: disable=C0103
170
+ self, request: PullObjectRequest, **kwargs: Any
171
+ ) -> PullObjectResponse:
172
+ """."""
173
+ return self._send_and_receive(request, PullObjectResponse, **kwargs)
174
+
175
+ def ConfirmMessageReceived( # pylint: disable=C0103
176
+ self, request: ConfirmMessageReceivedRequest, **kwargs: Any
177
+ ) -> ConfirmMessageReceivedResponse:
178
+ """."""
179
+ return self._send_and_receive(request, ConfirmMessageReceivedResponse, **kwargs)
@@ -164,7 +164,7 @@ def validate_out_message(out_message: Message, in_message_metadata: Metadata) ->
164
164
  in_meta = in_message_metadata
165
165
  if ( # pylint: disable-next=too-many-boolean-expressions
166
166
  out_meta.run_id == in_meta.run_id
167
- and out_meta.message_id == "" # This will be generated by the server
167
+ and out_meta.message_id == out_message.object_id # Should match the object id
168
168
  and out_meta.src_node_id == in_meta.dst_node_id
169
169
  and out_meta.dst_node_id == in_meta.src_node_id
170
170
  and out_meta.reply_to_message_id == in_meta.message_id
@@ -32,14 +32,17 @@ def message_size_mod(
32
32
 
33
33
  This mod logs the size in bytes of the message being transmited.
34
34
  """
35
- message_size_in_bytes = 0
35
+ # Log the size of the incoming message in bytes
36
+ total_bytes = sum(record.count_bytes() for record in msg.content.values())
37
+ log(INFO, "Incoming message size: %i bytes", total_bytes)
36
38
 
37
- for record in msg.content.values():
38
- message_size_in_bytes += record.count_bytes()
39
+ # Call the next layer
40
+ msg = call_next(msg, ctxt)
39
41
 
40
- log(INFO, "Message size: %i bytes", message_size_in_bytes)
41
-
42
- return call_next(msg, ctxt)
42
+ # Log the size of the outgoing message in bytes
43
+ total_bytes = sum(record.count_bytes() for record in msg.content.values())
44
+ log(INFO, "Outgoing message size: %i bytes", total_bytes)
45
+ return msg
43
46
 
44
47
 
45
48
  def arrays_size_mod(
@@ -50,25 +53,41 @@ def arrays_size_mod(
50
53
  This mod logs the number of array elements transmitted in ``ArrayRecord`` objects
51
54
  of the message as well as their sizes in bytes.
52
55
  """
53
- model_size_stats = {}
54
- arrays_size_in_bytes = 0
56
+ # Log the ArrayRecord size statistics and the total size in the incoming message
57
+ array_record_size_stats = _get_array_record_size_stats(msg)
58
+ total_bytes = sum(stat["bytes"] for stat in array_record_size_stats.values())
59
+ if array_record_size_stats:
60
+ log(INFO, "Incoming `ArrayRecord` size statistics:")
61
+ log(INFO, array_record_size_stats)
62
+ log(INFO, "Total array elements received: %i bytes", total_bytes)
63
+
64
+ msg = call_next(msg, ctxt)
65
+
66
+ # Log the ArrayRecord size statistics and the total size in the outgoing message
67
+ array_record_size_stats = _get_array_record_size_stats(msg)
68
+ total_bytes = sum(stat["bytes"] for stat in array_record_size_stats.values())
69
+ if array_record_size_stats:
70
+ log(INFO, "Outgoing `ArrayRecord` size statistics:")
71
+ log(INFO, array_record_size_stats)
72
+ log(INFO, "Total array elements sent: %i bytes", total_bytes)
73
+ return msg
74
+
75
+
76
+ def _get_array_record_size_stats(
77
+ msg: Message,
78
+ ) -> dict[str, dict[str, int]]:
79
+ """Get `ArrayRecord` size statistics from the message."""
80
+ array_record_size_stats = {}
55
81
  for record_name, arr_record in msg.content.array_records.items():
56
82
  arr_record_bytes = arr_record.count_bytes()
57
- arrays_size_in_bytes += arr_record_bytes
58
83
  element_count = 0
59
84
  for array in arr_record.values():
60
85
  element_count += (
61
86
  int(np.prod(array.shape)) if array.shape else array.numpy().size
62
87
  )
63
88
 
64
- model_size_stats[f"{record_name}"] = {
89
+ array_record_size_stats[record_name] = {
65
90
  "elements": element_count,
66
91
  "bytes": arr_record_bytes,
67
92
  }
68
-
69
- if model_size_stats:
70
- log(INFO, model_size_stats)
71
-
72
- log(INFO, "Total array elements transmitted: %i bytes", arrays_size_in_bytes)
73
-
74
- return call_next(msg, ctxt)
93
+ return array_record_size_stats