flwr 1.18.0__py3-none-any.whl → 1.20.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (174)
  1. flwr/app/__init__.py +15 -0
  2. flwr/app/error.py +68 -0
  3. flwr/app/metadata.py +223 -0
  4. flwr/cli/build.py +94 -59
  5. flwr/cli/log.py +3 -3
  6. flwr/cli/login/login.py +3 -7
  7. flwr/cli/ls.py +15 -36
  8. flwr/cli/new/new.py +12 -4
  9. flwr/cli/new/templates/app/README.flowertune.md.tpl +2 -0
  10. flwr/cli/new/templates/app/README.md.tpl +5 -0
  11. flwr/cli/new/templates/app/code/client.baseline.py.tpl +1 -1
  12. flwr/cli/new/templates/app/code/model.baseline.py.tpl +1 -1
  13. flwr/cli/new/templates/app/code/server.baseline.py.tpl +2 -3
  14. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +25 -17
  15. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +13 -1
  16. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +21 -2
  17. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +18 -1
  18. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +19 -2
  19. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +18 -1
  20. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +20 -3
  21. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +18 -1
  22. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +18 -1
  23. flwr/cli/run/run.py +48 -49
  24. flwr/cli/stop.py +2 -2
  25. flwr/cli/utils.py +38 -5
  26. flwr/client/__init__.py +2 -2
  27. flwr/client/client_app.py +1 -1
  28. flwr/client/clientapp/__init__.py +0 -7
  29. flwr/client/grpc_adapter_client/connection.py +15 -8
  30. flwr/client/grpc_rere_client/connection.py +142 -97
  31. flwr/client/grpc_rere_client/grpc_adapter.py +34 -6
  32. flwr/client/message_handler/message_handler.py +1 -1
  33. flwr/client/mod/comms_mods.py +36 -17
  34. flwr/client/rest_client/connection.py +176 -103
  35. flwr/clientapp/__init__.py +15 -0
  36. flwr/common/__init__.py +2 -2
  37. flwr/common/auth_plugin/__init__.py +2 -0
  38. flwr/common/auth_plugin/auth_plugin.py +29 -3
  39. flwr/common/constant.py +39 -8
  40. flwr/common/event_log_plugin/event_log_plugin.py +3 -3
  41. flwr/common/exit/exit_code.py +16 -1
  42. flwr/common/exit_handlers.py +30 -0
  43. flwr/common/grpc.py +12 -1
  44. flwr/common/heartbeat.py +165 -0
  45. flwr/common/inflatable.py +290 -0
  46. flwr/common/inflatable_protobuf_utils.py +141 -0
  47. flwr/common/inflatable_utils.py +508 -0
  48. flwr/common/message.py +110 -242
  49. flwr/common/record/__init__.py +2 -1
  50. flwr/common/record/array.py +402 -0
  51. flwr/common/record/arraychunk.py +59 -0
  52. flwr/common/record/arrayrecord.py +103 -225
  53. flwr/common/record/configrecord.py +59 -4
  54. flwr/common/record/conversion_utils.py +1 -1
  55. flwr/common/record/metricrecord.py +55 -4
  56. flwr/common/record/recorddict.py +69 -1
  57. flwr/common/recorddict_compat.py +2 -2
  58. flwr/common/retry_invoker.py +5 -1
  59. flwr/common/serde.py +59 -211
  60. flwr/common/serde_utils.py +175 -0
  61. flwr/common/typing.py +5 -3
  62. flwr/compat/__init__.py +15 -0
  63. flwr/compat/client/__init__.py +15 -0
  64. flwr/{client → compat/client}/app.py +28 -185
  65. flwr/compat/common/__init__.py +15 -0
  66. flwr/compat/server/__init__.py +15 -0
  67. flwr/compat/server/app.py +174 -0
  68. flwr/compat/simulation/__init__.py +15 -0
  69. flwr/proto/appio_pb2.py +43 -0
  70. flwr/proto/appio_pb2.pyi +151 -0
  71. flwr/proto/appio_pb2_grpc.py +4 -0
  72. flwr/proto/appio_pb2_grpc.pyi +4 -0
  73. flwr/proto/clientappio_pb2.py +12 -19
  74. flwr/proto/clientappio_pb2.pyi +23 -101
  75. flwr/proto/clientappio_pb2_grpc.py +269 -28
  76. flwr/proto/clientappio_pb2_grpc.pyi +114 -20
  77. flwr/proto/fleet_pb2.py +24 -27
  78. flwr/proto/fleet_pb2.pyi +19 -35
  79. flwr/proto/fleet_pb2_grpc.py +117 -13
  80. flwr/proto/fleet_pb2_grpc.pyi +47 -6
  81. flwr/proto/heartbeat_pb2.py +33 -0
  82. flwr/proto/heartbeat_pb2.pyi +66 -0
  83. flwr/proto/heartbeat_pb2_grpc.py +4 -0
  84. flwr/proto/heartbeat_pb2_grpc.pyi +4 -0
  85. flwr/proto/message_pb2.py +28 -11
  86. flwr/proto/message_pb2.pyi +125 -0
  87. flwr/proto/recorddict_pb2.py +16 -28
  88. flwr/proto/recorddict_pb2.pyi +46 -64
  89. flwr/proto/run_pb2.py +24 -32
  90. flwr/proto/run_pb2.pyi +4 -52
  91. flwr/proto/serverappio_pb2.py +9 -23
  92. flwr/proto/serverappio_pb2.pyi +0 -110
  93. flwr/proto/serverappio_pb2_grpc.py +177 -72
  94. flwr/proto/serverappio_pb2_grpc.pyi +75 -33
  95. flwr/proto/simulationio_pb2.py +12 -11
  96. flwr/proto/simulationio_pb2_grpc.py +35 -0
  97. flwr/proto/simulationio_pb2_grpc.pyi +14 -0
  98. flwr/server/__init__.py +1 -1
  99. flwr/server/app.py +69 -187
  100. flwr/server/compat/app_utils.py +50 -28
  101. flwr/server/fleet_event_log_interceptor.py +6 -2
  102. flwr/server/grid/grpc_grid.py +148 -41
  103. flwr/server/grid/inmemory_grid.py +5 -4
  104. flwr/server/serverapp/app.py +45 -17
  105. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +21 -3
  106. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +102 -8
  107. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -5
  108. flwr/server/superlink/fleet/message_handler/message_handler.py +130 -19
  109. flwr/server/superlink/fleet/rest_rere/rest_api.py +73 -13
  110. flwr/server/superlink/fleet/vce/vce_api.py +6 -3
  111. flwr/server/superlink/linkstate/in_memory_linkstate.py +138 -43
  112. flwr/server/superlink/linkstate/linkstate.py +53 -20
  113. flwr/server/superlink/linkstate/sqlite_linkstate.py +149 -55
  114. flwr/server/superlink/linkstate/utils.py +33 -29
  115. flwr/server/superlink/serverappio/serverappio_grpc.py +4 -1
  116. flwr/server/superlink/serverappio/serverappio_servicer.py +230 -84
  117. flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
  118. flwr/server/superlink/simulation/simulationio_servicer.py +26 -2
  119. flwr/server/superlink/utils.py +9 -2
  120. flwr/server/utils/validator.py +2 -2
  121. flwr/serverapp/__init__.py +15 -0
  122. flwr/simulation/app.py +25 -0
  123. flwr/simulation/run_simulation.py +17 -0
  124. flwr/supercore/__init__.py +15 -0
  125. flwr/{server/superlink → supercore}/ffs/__init__.py +2 -0
  126. flwr/{server/superlink → supercore}/ffs/disk_ffs.py +1 -1
  127. flwr/supercore/grpc_health/__init__.py +22 -0
  128. flwr/supercore/grpc_health/simple_health_servicer.py +38 -0
  129. flwr/supercore/license_plugin/__init__.py +22 -0
  130. flwr/supercore/license_plugin/license_plugin.py +26 -0
  131. flwr/supercore/object_store/__init__.py +24 -0
  132. flwr/supercore/object_store/in_memory_object_store.py +229 -0
  133. flwr/supercore/object_store/object_store.py +170 -0
  134. flwr/supercore/object_store/object_store_factory.py +44 -0
  135. flwr/supercore/object_store/utils.py +43 -0
  136. flwr/supercore/scheduler/__init__.py +22 -0
  137. flwr/supercore/scheduler/plugin.py +71 -0
  138. flwr/{client/nodestate/nodestate.py → supercore/utils.py} +14 -13
  139. flwr/superexec/deployment.py +7 -4
  140. flwr/superexec/exec_event_log_interceptor.py +8 -4
  141. flwr/superexec/exec_grpc.py +25 -5
  142. flwr/superexec/exec_license_interceptor.py +82 -0
  143. flwr/superexec/exec_servicer.py +135 -24
  144. flwr/superexec/exec_user_auth_interceptor.py +45 -8
  145. flwr/superexec/executor.py +5 -1
  146. flwr/superexec/simulation.py +8 -3
  147. flwr/superlink/__init__.py +15 -0
  148. flwr/{client/supernode → supernode}/__init__.py +0 -7
  149. flwr/supernode/cli/__init__.py +24 -0
  150. flwr/{client/supernode/app.py → supernode/cli/flower_supernode.py} +3 -19
  151. flwr/supernode/cli/flwr_clientapp.py +88 -0
  152. flwr/supernode/nodestate/in_memory_nodestate.py +199 -0
  153. flwr/supernode/nodestate/nodestate.py +227 -0
  154. flwr/supernode/runtime/__init__.py +15 -0
  155. flwr/{client/clientapp/app.py → supernode/runtime/run_clientapp.py} +135 -89
  156. flwr/supernode/scheduler/__init__.py +22 -0
  157. flwr/supernode/scheduler/simple_clientapp_scheduler_plugin.py +49 -0
  158. flwr/supernode/servicer/__init__.py +15 -0
  159. flwr/supernode/servicer/clientappio/__init__.py +22 -0
  160. flwr/supernode/servicer/clientappio/clientappio_servicer.py +303 -0
  161. flwr/supernode/start_client_internal.py +589 -0
  162. {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/METADATA +6 -4
  163. {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/RECORD +171 -123
  164. {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/WHEEL +1 -1
  165. {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/entry_points.txt +2 -2
  166. flwr/client/clientapp/clientappio_servicer.py +0 -244
  167. flwr/client/heartbeat.py +0 -74
  168. flwr/client/nodestate/in_memory_nodestate.py +0 -38
  169. /flwr/{client → compat/client}/grpc_client/__init__.py +0 -0
  170. /flwr/{client → compat/client}/grpc_client/connection.py +0 -0
  171. /flwr/{server/superlink → supercore}/ffs/ffs.py +0 -0
  172. /flwr/{server/superlink → supercore}/ffs/ffs_factory.py +0 -0
  173. /flwr/{client → supernode}/nodestate/__init__.py +0 -0
  174. /flwr/{client → supernode}/nodestate/nodestate_factory.py +0 -0
flwr/cli/run/run.py CHANGED
@@ -24,9 +24,8 @@ from typing import Annotated, Any, Optional
24
24
  import typer
25
25
  from rich.console import Console
26
26
 
27
- from flwr.cli.build import build
27
+ from flwr.cli.build import build_fab, get_fab_filename
28
28
  from flwr.cli.config_utils import (
29
- get_fab_metadata,
30
29
  load_and_validate,
31
30
  process_loaded_project_config,
32
31
  validate_federation_in_project_config,
@@ -34,6 +33,7 @@ from flwr.cli.config_utils import (
34
33
  from flwr.cli.constant import FEDERATION_CONFIG_HELP_MESSAGE
35
34
  from flwr.common.config import (
36
35
  flatten_dict,
36
+ get_metadata_from_config,
37
37
  parse_config_args,
38
38
  user_config_to_configrecord,
39
39
  )
@@ -45,11 +45,7 @@ from flwr.proto.exec_pb2 import StartRunRequest # pylint: disable=E0611
45
45
  from flwr.proto.exec_pb2_grpc import ExecStub
46
46
 
47
47
  from ..log import start_stream
48
- from ..utils import (
49
- init_channel,
50
- try_obtain_cli_auth_plugin,
51
- unauthenticated_exc_handler,
52
- )
48
+ from ..utils import flwr_cli_grpc_exc_handler, init_channel, try_obtain_cli_auth_plugin
53
49
 
54
50
  CONN_REFRESH_PERIOD = 60 # Connection refresh period for log streaming (seconds)
55
51
 
@@ -154,54 +150,57 @@ def _run_with_exec_api(
154
150
  stream: bool,
155
151
  output_format: str,
156
152
  ) -> None:
157
- auth_plugin = try_obtain_cli_auth_plugin(app, federation, federation_config)
158
- channel = init_channel(app, federation_config, auth_plugin)
159
- stub = ExecStub(channel)
160
-
161
- fab_path, fab_hash = build(app)
162
- content = Path(fab_path).read_bytes()
163
- fab_id, fab_version = get_fab_metadata(Path(fab_path))
153
+ channel = None
154
+ try:
155
+ auth_plugin = try_obtain_cli_auth_plugin(app, federation, federation_config)
156
+ channel = init_channel(app, federation_config, auth_plugin)
157
+ stub = ExecStub(channel)
164
158
 
165
- # Delete FAB file once the bytes is computed
166
- Path(fab_path).unlink()
159
+ fab_bytes, fab_hash, config = build_fab(app)
160
+ fab_id, fab_version = get_metadata_from_config(config)
167
161
 
168
- fab = Fab(fab_hash, content)
162
+ fab = Fab(fab_hash, fab_bytes)
169
163
 
170
- # Construct a `ConfigRecord` out of a flattened `UserConfig`
171
- fed_conf = flatten_dict(federation_config.get("options", {}))
172
- c_record = user_config_to_configrecord(fed_conf)
164
+ # Construct a `ConfigRecord` out of a flattened `UserConfig`
165
+ fed_config = flatten_dict(federation_config.get("options", {}))
166
+ c_record = user_config_to_configrecord(fed_config)
173
167
 
174
- req = StartRunRequest(
175
- fab=fab_to_proto(fab),
176
- override_config=user_config_to_proto(parse_config_args(config_overrides)),
177
- federation_options=config_record_to_proto(c_record),
178
- )
179
- with unauthenticated_exc_handler():
180
- res = stub.StartRun(req)
181
-
182
- if res.HasField("run_id"):
183
- typer.secho(f"🎊 Successfully started run {res.run_id}", fg=typer.colors.GREEN)
184
- else:
185
- typer.secho("❌ Failed to start run", fg=typer.colors.RED)
186
- raise typer.Exit(code=1)
187
-
188
- if output_format == CliOutputFormat.JSON:
189
- run_output = json.dumps(
190
- {
191
- "success": res.HasField("run_id"),
192
- "run-id": res.run_id if res.HasField("run_id") else None,
193
- "fab-id": fab_id,
194
- "fab-name": fab_id.rsplit("/", maxsplit=1)[-1],
195
- "fab-version": fab_version,
196
- "fab-hash": fab_hash[:8],
197
- "fab-filename": fab_path,
198
- }
168
+ req = StartRunRequest(
169
+ fab=fab_to_proto(fab),
170
+ override_config=user_config_to_proto(parse_config_args(config_overrides)),
171
+ federation_options=config_record_to_proto(c_record),
199
172
  )
200
- restore_output()
201
- Console().print_json(run_output)
173
+ with flwr_cli_grpc_exc_handler():
174
+ res = stub.StartRun(req)
202
175
 
203
- if stream:
204
- start_stream(res.run_id, channel, CONN_REFRESH_PERIOD)
176
+ if res.HasField("run_id"):
177
+ typer.secho(
178
+ f"🎊 Successfully started run {res.run_id}", fg=typer.colors.GREEN
179
+ )
180
+ else:
181
+ typer.secho("❌ Failed to start run", fg=typer.colors.RED)
182
+ raise typer.Exit(code=1)
183
+
184
+ if output_format == CliOutputFormat.JSON:
185
+ run_output = json.dumps(
186
+ {
187
+ "success": res.HasField("run_id"),
188
+ "run-id": res.run_id if res.HasField("run_id") else None,
189
+ "fab-id": fab_id,
190
+ "fab-name": fab_id.rsplit("/", maxsplit=1)[-1],
191
+ "fab-version": fab_version,
192
+ "fab-hash": fab_hash[:8],
193
+ "fab-filename": get_fab_filename(config, fab_hash),
194
+ }
195
+ )
196
+ restore_output()
197
+ Console().print_json(run_output)
198
+
199
+ if stream:
200
+ start_stream(res.run_id, channel, CONN_REFRESH_PERIOD)
201
+ finally:
202
+ if channel:
203
+ channel.close()
205
204
 
206
205
 
207
206
  def _run_without_exec_api(
flwr/cli/stop.py CHANGED
@@ -35,7 +35,7 @@ from flwr.common.logger import print_json_error, redirect_output, restore_output
35
35
  from flwr.proto.exec_pb2 import StopRunRequest, StopRunResponse # pylint: disable=E0611
36
36
  from flwr.proto.exec_pb2_grpc import ExecStub
37
37
 
38
- from .utils import init_channel, try_obtain_cli_auth_plugin, unauthenticated_exc_handler
38
+ from .utils import flwr_cli_grpc_exc_handler, init_channel, try_obtain_cli_auth_plugin
39
39
 
40
40
 
41
41
  def stop( # pylint: disable=R0914
@@ -122,7 +122,7 @@ def stop( # pylint: disable=R0914
122
122
 
123
123
  def _stop_run(stub: ExecStub, run_id: int, output_format: str) -> None:
124
124
  """Stop a run."""
125
- with unauthenticated_exc_handler():
125
+ with flwr_cli_grpc_exc_handler():
126
126
  response: StopRunResponse = stub.StopRun(request=StopRunRequest(run_id=run_id))
127
127
  if response.success:
128
128
  typer.secho(f"✅ Run {run_id} successfully stopped.", fg=typer.colors.GREEN)
flwr/cli/utils.py CHANGED
@@ -28,7 +28,12 @@ import typer
28
28
 
29
29
  from flwr.cli.cli_user_auth_interceptor import CliUserAuthInterceptor
30
30
  from flwr.common.auth_plugin import CliAuthPlugin
31
- from flwr.common.constant import AUTH_TYPE_JSON_KEY, CREDENTIALS_DIR, FLWR_DIR
31
+ from flwr.common.constant import (
32
+ AUTH_TYPE_JSON_KEY,
33
+ CREDENTIALS_DIR,
34
+ FLWR_DIR,
35
+ RUN_ID_NOT_FOUND_MESSAGE,
36
+ )
32
37
  from flwr.common.grpc import (
33
38
  GRPC_MAX_MESSAGE_LENGTH,
34
39
  create_channel,
@@ -288,11 +293,12 @@ def init_channel(
288
293
 
289
294
 
290
295
  @contextmanager
291
- def unauthenticated_exc_handler() -> Iterator[None]:
292
- """Context manager to handle gRPC UNAUTHENTICATED errors.
296
+ def flwr_cli_grpc_exc_handler() -> Iterator[None]:
297
+ """Context manager to handle specific gRPC errors.
293
298
 
294
- It catches grpc.RpcError exceptions with UNAUTHENTICATED status, informs the user,
295
- and exits the application. All other exceptions will be allowed to escape.
299
+ It catches grpc.RpcError exceptions with UNAUTHENTICATED, UNIMPLEMENTED,
300
+ UNAVAILABLE, and PERMISSION_DENIED statuses, informs the user, and exits the
301
+ application. All other exceptions will be allowed to escape.
296
302
  """
297
303
  try:
298
304
  yield
@@ -312,4 +318,31 @@ def unauthenticated_exc_handler() -> Iterator[None]:
312
318
  bold=True,
313
319
  )
314
320
  raise typer.Exit(code=1) from None
321
+ if e.code() == grpc.StatusCode.PERMISSION_DENIED:
322
+ typer.secho(
323
+ "❌ Permission denied.",
324
+ fg=typer.colors.RED,
325
+ bold=True,
326
+ )
327
+ # pylint: disable=E1101
328
+ typer.secho(e.details(), fg=typer.colors.RED, bold=True)
329
+ raise typer.Exit(code=1) from None
330
+ if e.code() == grpc.StatusCode.UNAVAILABLE:
331
+ typer.secho(
332
+ "Connection to the SuperLink is unavailable. Please check your network "
333
+ "connection and 'address' in the federation configuration.",
334
+ fg=typer.colors.RED,
335
+ bold=True,
336
+ )
337
+ raise typer.Exit(code=1) from None
338
+ if (
339
+ e.code() == grpc.StatusCode.NOT_FOUND
340
+ and e.details() == RUN_ID_NOT_FOUND_MESSAGE
341
+ ):
342
+ typer.secho(
343
+ "❌ Run ID not found.",
344
+ fg=typer.colors.RED,
345
+ bold=True,
346
+ )
347
+ raise typer.Exit(code=1) from None
315
348
  raise
flwr/client/__init__.py CHANGED
@@ -15,8 +15,8 @@
15
15
  """Flower client."""
16
16
 
17
17
 
18
- from .app import start_client as start_client
19
- from .app import start_numpy_client as start_numpy_client
18
+ from ..compat.client.app import start_client as start_client # Deprecated
19
+ from ..compat.client.app import start_numpy_client as start_numpy_client # Deprecated
20
20
  from .client import Client as Client
21
21
  from .client_app import ClientApp as ClientApp
22
22
  from .numpy_client import NumPyClient as NumPyClient
flwr/client/client_app.py CHANGED
@@ -20,6 +20,7 @@ from collections.abc import Iterator
20
20
  from contextlib import contextmanager
21
21
  from typing import Callable, Optional
22
22
 
23
+ from flwr.app.metadata import validate_message_type
23
24
  from flwr.client.client import Client
24
25
  from flwr.client.message_handler.message_handler import (
25
26
  handle_legacy_message_from_msgtype,
@@ -28,7 +29,6 @@ from flwr.client.mod.utils import make_ffn
28
29
  from flwr.client.typing import ClientFnExt, Mod
29
30
  from flwr.common import Context, Message, MessageType
30
31
  from flwr.common.logger import warn_deprecated_feature
31
- from flwr.common.message import validate_message_type
32
32
 
33
33
  from .typing import ClientAppCallable
34
34
 
@@ -13,10 +13,3 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
  """Flower AppIO service."""
16
-
17
-
18
- from .app import flwr_clientapp as flwr_clientapp
19
-
20
- __all__ = [
21
- "flwr_clientapp",
22
- ]
@@ -29,6 +29,7 @@ from flwr.common.logger import log
29
29
  from flwr.common.message import Message
30
30
  from flwr.common.retry_invoker import RetryInvoker
31
31
  from flwr.common.typing import Fab, Run
32
+ from flwr.proto.message_pb2 import ObjectTree # pylint: disable=E0611
32
33
 
33
34
 
34
35
  @contextmanager
@@ -43,12 +44,15 @@ def grpc_adapter( # pylint: disable=R0913,too-many-positional-arguments
43
44
  ] = None,
44
45
  ) -> Iterator[
45
46
  tuple[
46
- Callable[[], Optional[Message]],
47
- Callable[[Message], None],
48
- Optional[Callable[[], Optional[int]]],
49
- Optional[Callable[[], None]],
50
- Optional[Callable[[int], Run]],
51
- Optional[Callable[[str, int], Fab]],
47
+ Callable[[], Optional[tuple[Message, ObjectTree]]],
48
+ Callable[[Message, ObjectTree], set[str]],
49
+ Callable[[], Optional[int]],
50
+ Callable[[], None],
51
+ Callable[[int], Run],
52
+ Callable[[str, int], Fab],
53
+ Callable[[int, str], bytes],
54
+ Callable[[int, str, bytes], None],
55
+ Callable[[int, str], None],
52
56
  ]
53
57
  ]:
54
58
  """Primitives for request/response-based interaction with a server via GrpcAdapter.
@@ -77,12 +81,15 @@ def grpc_adapter( # pylint: disable=R0913,too-many-positional-arguments
77
81
 
78
82
  Returns
79
83
  -------
80
- receive : Callable
81
- send : Callable
84
+ receive : Callable[[], Optional[tuple[Message, ObjectTree]]]
85
+ send : Callable[[Message, ObjectTree], set[str]]
82
86
  create_node : Optional[Callable]
83
87
  delete_node : Optional[Callable]
84
88
  get_run : Optional[Callable]
85
89
  get_fab : Optional[Callable]
90
+ pull_object : Callable[[str], bytes]
91
+ push_object : Callable[[str, bytes], None]
92
+ confirm_message_received : Callable[[str], None]
86
93
  """
87
94
  if authentication_keys is not None:
88
95
  log(ERROR, "Client authentication is not supported for this transport type.")
@@ -15,11 +15,8 @@
15
15
  """Contextmanager for a gRPC request-response channel to the Flower server."""
16
16
 
17
17
 
18
- import random
19
- import threading
20
18
  from collections.abc import Iterator, Sequence
21
19
  from contextlib import contextmanager
22
- from copy import copy
23
20
  from logging import ERROR
24
21
  from pathlib import Path
25
22
  from typing import Callable, Optional, Union, cast
@@ -27,19 +24,18 @@ from typing import Callable, Optional, Union, cast
27
24
  import grpc
28
25
  from cryptography.hazmat.primitives.asymmetric import ec
29
26
 
30
- from flwr.client.heartbeat import start_ping_loop
31
- from flwr.client.message_handler.message_handler import validate_out_message
32
27
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH
33
- from flwr.common.constant import (
34
- PING_BASE_MULTIPLIER,
35
- PING_CALL_TIMEOUT,
36
- PING_DEFAULT_INTERVAL,
37
- PING_RANDOM_RANGE,
38
- )
28
+ from flwr.common.constant import HEARTBEAT_CALL_TIMEOUT, HEARTBEAT_DEFAULT_INTERVAL
39
29
  from flwr.common.grpc import create_channel, on_channel_state_change
30
+ from flwr.common.heartbeat import HeartbeatSender
31
+ from flwr.common.inflatable_protobuf_utils import (
32
+ make_confirm_message_received_fn_protobuf,
33
+ make_pull_object_fn_protobuf,
34
+ make_push_object_fn_protobuf,
35
+ )
40
36
  from flwr.common.logger import log
41
- from flwr.common.message import Message, Metadata
42
- from flwr.common.retry_invoker import RetryInvoker
37
+ from flwr.common.message import Message, remove_content_from_message
38
+ from flwr.common.retry_invoker import RetryInvoker, _wrap_stub
43
39
  from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
44
40
  generate_key_pairs,
45
41
  )
@@ -49,13 +45,17 @@ from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=
49
45
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
50
46
  CreateNodeRequest,
51
47
  DeleteNodeRequest,
52
- PingRequest,
53
- PingResponse,
54
48
  PullMessagesRequest,
55
49
  PullMessagesResponse,
56
50
  PushMessagesRequest,
51
+ PushMessagesResponse,
57
52
  )
58
53
  from flwr.proto.fleet_pb2_grpc import FleetStub # pylint: disable=E0611
54
+ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
55
+ SendNodeHeartbeatRequest,
56
+ SendNodeHeartbeatResponse,
57
+ )
58
+ from flwr.proto.message_pb2 import ObjectTree # pylint: disable=E0611
59
59
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
60
60
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
61
61
 
@@ -76,12 +76,15 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
76
76
  adapter_cls: Optional[Union[type[FleetStub], type[GrpcAdapter]]] = None,
77
77
  ) -> Iterator[
78
78
  tuple[
79
- Callable[[], Optional[Message]],
80
- Callable[[Message], None],
81
- Optional[Callable[[], Optional[int]]],
82
- Optional[Callable[[], None]],
83
- Optional[Callable[[int], Run]],
84
- Optional[Callable[[str, int], Fab]],
79
+ Callable[[], Optional[tuple[Message, ObjectTree]]],
80
+ Callable[[Message, ObjectTree], set[str]],
81
+ Callable[[], Optional[int]],
82
+ Callable[[], None],
83
+ Callable[[int], Run],
84
+ Callable[[str, int], Fab],
85
+ Callable[[int, str], bytes],
86
+ Callable[[int, str, bytes], None],
87
+ Callable[[int, str], None],
85
88
  ]
86
89
  ]:
87
90
  """Primitives for request/response-based interaction with a server.
@@ -124,6 +127,9 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
124
127
  create_node : Optional[Callable]
125
128
  delete_node : Optional[Callable]
126
129
  get_run : Optional[Callable]
130
+ pull_object : Callable[[str], bytes]
131
+ push_object : Callable[[str, bytes], None]
132
+ confirm_message_received : Callable[[str], None]
127
133
  """
128
134
  if isinstance(root_certificates, str):
129
135
  root_certificates = Path(root_certificates).read_bytes()
@@ -149,10 +155,7 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
149
155
  if adapter_cls is None:
150
156
  adapter_cls = FleetStub
151
157
  stub = adapter_cls(channel)
152
- metadata: Optional[Metadata] = None
153
158
  node: Optional[Node] = None
154
- ping_thread: Optional[threading.Thread] = None
155
- ping_stop_event = threading.Event()
156
159
 
157
160
  def _should_giveup_fn(e: Exception) -> bool:
158
161
  if e.code() == grpc.StatusCode.PERMISSION_DENIED: # type: ignore
@@ -165,46 +168,58 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
165
168
  # If the status code is PERMISSION_DENIED, additionally raise RunNotRunningException
166
169
  retry_invoker.should_giveup = _should_giveup_fn
167
170
 
171
+ # Wrap stub
172
+ _wrap_stub(stub, retry_invoker)
168
173
  ###########################################################################
169
- # ping/create_node/delete_node/receive/send/get_run functions
174
+ # send_node_heartbeat/create_node/delete_node/receive/send/get_run functions
170
175
  ###########################################################################
171
176
 
172
- def ping() -> None:
177
+ def send_node_heartbeat() -> bool:
173
178
  # Get Node
174
179
  if node is None:
175
180
  log(ERROR, "Node instance missing")
176
- return
181
+ return False
177
182
 
178
- # Construct the ping request
179
- req = PingRequest(node=node, ping_interval=PING_DEFAULT_INTERVAL)
183
+ # Construct the heartbeat request
184
+ req = SendNodeHeartbeatRequest(
185
+ node=node, heartbeat_interval=HEARTBEAT_DEFAULT_INTERVAL
186
+ )
180
187
 
181
188
  # Call FleetAPI
182
- res: PingResponse = stub.Ping(req, timeout=PING_CALL_TIMEOUT)
189
+ try:
190
+ res: SendNodeHeartbeatResponse = stub.SendNodeHeartbeat(
191
+ req, timeout=HEARTBEAT_CALL_TIMEOUT
192
+ )
193
+ except grpc.RpcError as e:
194
+ status_code = e.code()
195
+ if status_code == grpc.StatusCode.UNAVAILABLE:
196
+ return False
197
+ if status_code == grpc.StatusCode.DEADLINE_EXCEEDED:
198
+ return False
199
+ raise
183
200
 
184
201
  # Check if success
185
202
  if not res.success:
186
- raise RuntimeError("Ping failed unexpectedly.")
203
+ raise RuntimeError(
204
+ "Heartbeat failed unexpectedly. The SuperLink does not "
205
+ "recognize this SuperNode."
206
+ )
207
+ return True
187
208
 
188
- # Wait
189
- rd = random.uniform(*PING_RANDOM_RANGE)
190
- next_interval: float = PING_DEFAULT_INTERVAL - PING_CALL_TIMEOUT
191
- next_interval *= PING_BASE_MULTIPLIER + rd
192
- if not ping_stop_event.is_set():
193
- ping_stop_event.wait(next_interval)
209
+ heartbeat_sender = HeartbeatSender(send_node_heartbeat)
194
210
 
195
211
  def create_node() -> Optional[int]:
196
212
  """Set create_node."""
197
213
  # Call FleetAPI
198
- create_node_request = CreateNodeRequest(ping_interval=PING_DEFAULT_INTERVAL)
199
- create_node_response = retry_invoker.invoke(
200
- stub.CreateNode,
201
- request=create_node_request,
214
+ create_node_request = CreateNodeRequest(
215
+ heartbeat_interval=HEARTBEAT_DEFAULT_INTERVAL
202
216
  )
217
+ create_node_response = stub.CreateNode(request=create_node_request)
203
218
 
204
- # Remember the node and the ping-loop thread
205
- nonlocal node, ping_thread
219
+ # Remember the node and start the heartbeat sender
220
+ nonlocal node
206
221
  node = cast(Node, create_node_response.node)
207
- ping_thread = start_ping_loop(ping, ping_stop_event)
222
+ heartbeat_sender.start()
208
223
  return node.node_id
209
224
 
210
225
  def delete_node() -> None:
@@ -215,83 +230,67 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
215
230
  log(ERROR, "Node instance missing")
216
231
  return
217
232
 
218
- # Stop the ping-loop thread
219
- ping_stop_event.set()
233
+ # Stop the heartbeat sender
234
+ heartbeat_sender.stop()
220
235
 
221
236
  # Call FleetAPI
222
237
  delete_node_request = DeleteNodeRequest(node=node)
223
- retry_invoker.invoke(stub.DeleteNode, request=delete_node_request)
238
+ stub.DeleteNode(request=delete_node_request)
224
239
 
225
240
  # Cleanup
226
241
  node = None
227
242
 
228
- def receive() -> Optional[Message]:
229
- """Receive next message from server."""
243
+ def receive() -> Optional[tuple[Message, ObjectTree]]:
244
+ """Pull a message with its ObjectTree from SuperLink."""
230
245
  # Get Node
231
246
  if node is None:
232
247
  log(ERROR, "Node instance missing")
233
248
  return None
234
249
 
235
- # Request instructions (message) from server
250
+ # Try to pull a message with its object tree from SuperLink
236
251
  request = PullMessagesRequest(node=node)
237
- response: PullMessagesResponse = retry_invoker.invoke(
238
- stub.PullMessages, request=request
239
- )
252
+ response: PullMessagesResponse = stub.PullMessages(request=request)
240
253
 
241
- # Get the current Messages
242
- message_proto = (
243
- None if len(response.messages_list) == 0 else response.messages_list[0]
244
- )
254
+ # If no messages are available, return None
255
+ if len(response.messages_list) == 0:
256
+ return None
245
257
 
246
- # Discard the current message if not valid
247
- if message_proto is not None and not (
248
- message_proto.metadata.dst_node_id == node.node_id
249
- ):
250
- message_proto = None
258
+ # Get the current Message and its object tree
259
+ message_proto = response.messages_list[0]
260
+ object_tree = response.message_object_trees[0]
251
261
 
252
262
  # Construct the Message
253
- in_message = message_from_proto(message_proto) if message_proto else None
254
-
255
- # Remember `metadata` of the in message
256
- nonlocal metadata
257
- metadata = copy(in_message.metadata) if in_message else None
263
+ in_message = message_from_proto(message_proto)
258
264
 
259
- # Return the message if available
260
- return in_message
265
+ # Return the Message and its object tree
266
+ return in_message, object_tree
261
267
 
262
- def send(message: Message) -> None:
263
- """Send message reply to server."""
268
+ def send(message: Message, object_tree: ObjectTree) -> set[str]:
269
+ """Send the message with its ObjectTree to SuperLink."""
264
270
  # Get Node
265
271
  if node is None:
266
272
  log(ERROR, "Node instance missing")
267
- return
268
-
269
- # Get the metadata of the incoming message
270
- nonlocal metadata
271
- if metadata is None:
272
- log(ERROR, "No current message")
273
- return
273
+ return set()
274
274
 
275
- # Validate out message
276
- if not validate_out_message(message, metadata):
277
- log(ERROR, "Invalid out message")
278
- return
275
+ # Remove the content from the message if it has
276
+ if message.has_content():
277
+ message = remove_content_from_message(message)
279
278
 
280
- # Serialize Message
281
- message_proto = message_to_proto(message=message)
282
- request = PushMessagesRequest(node=node, messages_list=[message_proto])
283
- _ = retry_invoker.invoke(stub.PushMessages, request)
279
+ # Send the message with its ObjectTree to SuperLink
280
+ request = PushMessagesRequest(
281
+ node=node,
282
+ messages_list=[message_to_proto(message)],
283
+ message_object_trees=[object_tree],
284
+ )
285
+ response: PushMessagesResponse = stub.PushMessages(request=request)
284
286
 
285
- # Cleanup
286
- metadata = None
287
+ # Get and return the object IDs to push
288
+ return set(response.objects_to_push)
287
289
 
288
290
  def get_run(run_id: int) -> Run:
289
291
  # Call FleetAPI
290
292
  get_run_request = GetRunRequest(node=node, run_id=run_id)
291
- get_run_response: GetRunResponse = retry_invoker.invoke(
292
- stub.GetRun,
293
- request=get_run_request,
294
- )
293
+ get_run_response: GetRunResponse = stub.GetRun(request=get_run_request)
295
294
 
296
295
  # Return fab_id and fab_version
297
296
  return run_from_proto(get_run_response.run)
@@ -299,16 +298,62 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
299
298
  def get_fab(fab_hash: str, run_id: int) -> Fab:
300
299
  # Call FleetAPI
301
300
  get_fab_request = GetFabRequest(node=node, hash_str=fab_hash, run_id=run_id)
302
- get_fab_response: GetFabResponse = retry_invoker.invoke(
303
- stub.GetFab,
304
- request=get_fab_request,
305
- )
301
+ get_fab_response: GetFabResponse = stub.GetFab(request=get_fab_request)
306
302
 
307
303
  return Fab(get_fab_response.fab.hash_str, get_fab_response.fab.content)
308
304
 
305
+ def pull_object(run_id: int, object_id: str) -> bytes:
306
+ """Pull the object from the SuperLink."""
307
+ # Check Node
308
+ if node is None:
309
+ raise RuntimeError("Node instance missing")
310
+
311
+ fn = make_pull_object_fn_protobuf(
312
+ pull_object_protobuf=stub.PullObject,
313
+ node=node,
314
+ run_id=run_id,
315
+ )
316
+ return fn(object_id)
317
+
318
+ def push_object(run_id: int, object_id: str, contents: bytes) -> None:
319
+ """Push the object to the SuperLink."""
320
+ # Check Node
321
+ if node is None:
322
+ raise RuntimeError("Node instance missing")
323
+
324
+ fn = make_push_object_fn_protobuf(
325
+ push_object_protobuf=stub.PushObject,
326
+ node=node,
327
+ run_id=run_id,
328
+ )
329
+ fn(object_id, contents)
330
+
331
+ def confirm_message_received(run_id: int, object_id: str) -> None:
332
+ """Confirm that the message has been received."""
333
+ # Check Node
334
+ if node is None:
335
+ raise RuntimeError("Node instance missing")
336
+
337
+ fn = make_confirm_message_received_fn_protobuf(
338
+ confirm_message_received_protobuf=stub.ConfirmMessageReceived,
339
+ node=node,
340
+ run_id=run_id,
341
+ )
342
+ fn(object_id)
343
+
309
344
  try:
310
345
  # Yield methods
311
- yield (receive, send, create_node, delete_node, get_run, get_fab)
346
+ yield (
347
+ receive,
348
+ send,
349
+ create_node,
350
+ delete_node,
351
+ get_run,
352
+ get_fab,
353
+ pull_object,
354
+ push_object,
355
+ confirm_message_received,
356
+ )
312
357
  except Exception as exc: # pylint: disable=broad-except
313
358
  log(ERROR, exc)
314
359
  # Cleanup