flwr 1.19.0__py3-none-any.whl → 1.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (167) hide show
  1. flwr/__init__.py +4 -1
  2. flwr/app/__init__.py +28 -0
  3. flwr/app/exception.py +31 -0
  4. flwr/cli/auth_plugin/oidc_cli_plugin.py +4 -4
  5. flwr/cli/build.py +15 -5
  6. flwr/cli/cli_user_auth_interceptor.py +1 -1
  7. flwr/cli/config_utils.py +3 -3
  8. flwr/cli/constant.py +25 -8
  9. flwr/cli/log.py +9 -9
  10. flwr/cli/login/login.py +3 -3
  11. flwr/cli/ls.py +5 -5
  12. flwr/cli/new/new.py +23 -4
  13. flwr/cli/new/templates/app/README.flowertune.md.tpl +2 -0
  14. flwr/cli/new/templates/app/README.md.tpl +5 -0
  15. flwr/cli/new/templates/app/code/__init__.pytorch_msg_api.py.tpl +1 -0
  16. flwr/cli/new/templates/app/code/client.pytorch_msg_api.py.tpl +80 -0
  17. flwr/cli/new/templates/app/code/server.pytorch_msg_api.py.tpl +41 -0
  18. flwr/cli/new/templates/app/code/task.pytorch_msg_api.py.tpl +98 -0
  19. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +14 -3
  20. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +13 -1
  21. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +21 -2
  22. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +18 -1
  23. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +19 -2
  24. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +18 -1
  25. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +20 -3
  26. flwr/cli/new/templates/app/pyproject.pytorch_msg_api.toml.tpl +53 -0
  27. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +18 -1
  28. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +18 -1
  29. flwr/cli/run/run.py +53 -50
  30. flwr/cli/stop.py +7 -4
  31. flwr/cli/utils.py +29 -11
  32. flwr/client/grpc_adapter_client/connection.py +11 -4
  33. flwr/client/grpc_rere_client/connection.py +93 -129
  34. flwr/client/rest_client/connection.py +134 -164
  35. flwr/clientapp/__init__.py +10 -0
  36. flwr/clientapp/mod/__init__.py +26 -0
  37. flwr/clientapp/mod/centraldp_mods.py +132 -0
  38. flwr/common/args.py +20 -6
  39. flwr/common/auth_plugin/__init__.py +4 -4
  40. flwr/common/auth_plugin/auth_plugin.py +7 -7
  41. flwr/common/constant.py +26 -5
  42. flwr/common/event_log_plugin/event_log_plugin.py +1 -1
  43. flwr/common/exit/__init__.py +4 -0
  44. flwr/common/exit/exit.py +8 -1
  45. flwr/common/exit/exit_code.py +42 -8
  46. flwr/common/exit/exit_handler.py +62 -0
  47. flwr/common/{exit_handlers.py → exit/signal_handler.py} +20 -37
  48. flwr/common/grpc.py +1 -1
  49. flwr/common/{inflatable_grpc_utils.py → inflatable_protobuf_utils.py} +52 -10
  50. flwr/common/inflatable_utils.py +191 -24
  51. flwr/common/logger.py +1 -1
  52. flwr/common/record/array.py +101 -22
  53. flwr/common/record/arraychunk.py +59 -0
  54. flwr/common/retry_invoker.py +30 -11
  55. flwr/common/serde.py +0 -28
  56. flwr/common/telemetry.py +4 -0
  57. flwr/compat/client/app.py +14 -31
  58. flwr/compat/server/app.py +2 -2
  59. flwr/proto/appio_pb2.py +51 -0
  60. flwr/proto/appio_pb2.pyi +195 -0
  61. flwr/proto/appio_pb2_grpc.py +4 -0
  62. flwr/proto/appio_pb2_grpc.pyi +4 -0
  63. flwr/proto/clientappio_pb2.py +4 -19
  64. flwr/proto/clientappio_pb2.pyi +0 -125
  65. flwr/proto/clientappio_pb2_grpc.py +269 -29
  66. flwr/proto/clientappio_pb2_grpc.pyi +114 -21
  67. flwr/proto/control_pb2.py +62 -0
  68. flwr/proto/{exec_pb2_grpc.py → control_pb2_grpc.py} +54 -54
  69. flwr/proto/{exec_pb2_grpc.pyi → control_pb2_grpc.pyi} +28 -28
  70. flwr/proto/fleet_pb2.py +12 -20
  71. flwr/proto/fleet_pb2.pyi +6 -36
  72. flwr/proto/serverappio_pb2.py +8 -31
  73. flwr/proto/serverappio_pb2.pyi +0 -152
  74. flwr/proto/serverappio_pb2_grpc.py +107 -38
  75. flwr/proto/serverappio_pb2_grpc.pyi +47 -20
  76. flwr/proto/simulationio_pb2.py +4 -11
  77. flwr/proto/simulationio_pb2.pyi +0 -58
  78. flwr/proto/simulationio_pb2_grpc.py +129 -27
  79. flwr/proto/simulationio_pb2_grpc.pyi +52 -13
  80. flwr/server/app.py +130 -153
  81. flwr/server/fleet_event_log_interceptor.py +4 -0
  82. flwr/server/grid/grpc_grid.py +94 -54
  83. flwr/server/grid/inmemory_grid.py +1 -0
  84. flwr/server/serverapp/app.py +165 -144
  85. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +8 -0
  86. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +1 -1
  87. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -5
  88. flwr/server/superlink/fleet/message_handler/message_handler.py +10 -16
  89. flwr/server/superlink/fleet/rest_rere/rest_api.py +1 -2
  90. flwr/server/superlink/fleet/vce/backend/raybackend.py +3 -1
  91. flwr/server/superlink/fleet/vce/vce_api.py +6 -6
  92. flwr/server/superlink/linkstate/in_memory_linkstate.py +34 -0
  93. flwr/server/superlink/linkstate/linkstate.py +2 -1
  94. flwr/server/superlink/linkstate/sqlite_linkstate.py +45 -0
  95. flwr/server/superlink/serverappio/serverappio_grpc.py +2 -2
  96. flwr/server/superlink/serverappio/serverappio_servicer.py +95 -48
  97. flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
  98. flwr/server/superlink/simulation/simulationio_servicer.py +98 -22
  99. flwr/server/superlink/utils.py +0 -35
  100. flwr/serverapp/__init__.py +12 -0
  101. flwr/serverapp/dp_fixed_clipping.py +352 -0
  102. flwr/serverapp/exception.py +38 -0
  103. flwr/serverapp/strategy/__init__.py +38 -0
  104. flwr/serverapp/strategy/dp_fixed_clipping.py +352 -0
  105. flwr/serverapp/strategy/fedadagrad.py +162 -0
  106. flwr/serverapp/strategy/fedadam.py +181 -0
  107. flwr/serverapp/strategy/fedavg.py +295 -0
  108. flwr/serverapp/strategy/fedopt.py +218 -0
  109. flwr/serverapp/strategy/fedyogi.py +173 -0
  110. flwr/serverapp/strategy/result.py +105 -0
  111. flwr/serverapp/strategy/strategy.py +285 -0
  112. flwr/serverapp/strategy/strategy_utils.py +251 -0
  113. flwr/serverapp/strategy/strategy_utils_tests.py +304 -0
  114. flwr/simulation/app.py +159 -154
  115. flwr/simulation/run_simulation.py +17 -0
  116. flwr/supercore/app_utils.py +58 -0
  117. flwr/supercore/cli/__init__.py +22 -0
  118. flwr/supercore/cli/flower_superexec.py +141 -0
  119. flwr/supercore/corestate/__init__.py +22 -0
  120. flwr/supercore/corestate/corestate.py +81 -0
  121. flwr/{server/superlink → supercore}/ffs/disk_ffs.py +1 -1
  122. flwr/supercore/grpc_health/__init__.py +25 -0
  123. flwr/supercore/grpc_health/health_server.py +53 -0
  124. flwr/supercore/grpc_health/simple_health_servicer.py +38 -0
  125. flwr/supercore/license_plugin/__init__.py +22 -0
  126. flwr/supercore/license_plugin/license_plugin.py +26 -0
  127. flwr/supercore/object_store/in_memory_object_store.py +31 -31
  128. flwr/supercore/object_store/object_store.py +20 -42
  129. flwr/supercore/object_store/utils.py +43 -0
  130. flwr/{superexec → supercore/superexec}/__init__.py +1 -1
  131. flwr/supercore/superexec/plugin/__init__.py +28 -0
  132. flwr/supercore/superexec/plugin/base_exec_plugin.py +53 -0
  133. flwr/supercore/superexec/plugin/clientapp_exec_plugin.py +28 -0
  134. flwr/supercore/superexec/plugin/exec_plugin.py +71 -0
  135. flwr/supercore/superexec/plugin/serverapp_exec_plugin.py +28 -0
  136. flwr/supercore/superexec/plugin/simulation_exec_plugin.py +28 -0
  137. flwr/supercore/superexec/run_superexec.py +185 -0
  138. flwr/supercore/utils.py +32 -0
  139. flwr/superlink/servicer/__init__.py +15 -0
  140. flwr/superlink/servicer/control/__init__.py +22 -0
  141. flwr/{superexec/exec_event_log_interceptor.py → superlink/servicer/control/control_event_log_interceptor.py} +9 -5
  142. flwr/{superexec/exec_grpc.py → superlink/servicer/control/control_grpc.py} +39 -28
  143. flwr/superlink/servicer/control/control_license_interceptor.py +82 -0
  144. flwr/{superexec/exec_servicer.py → superlink/servicer/control/control_servicer.py} +79 -31
  145. flwr/{superexec/exec_user_auth_interceptor.py → superlink/servicer/control/control_user_auth_interceptor.py} +18 -10
  146. flwr/supernode/cli/flower_supernode.py +3 -7
  147. flwr/supernode/cli/flwr_clientapp.py +20 -16
  148. flwr/supernode/nodestate/in_memory_nodestate.py +13 -4
  149. flwr/supernode/nodestate/nodestate.py +3 -44
  150. flwr/supernode/runtime/run_clientapp.py +129 -115
  151. flwr/supernode/servicer/clientappio/__init__.py +1 -3
  152. flwr/supernode/servicer/clientappio/clientappio_servicer.py +217 -165
  153. flwr/supernode/start_client_internal.py +205 -148
  154. {flwr-1.19.0.dist-info → flwr-1.21.0.dist-info}/METADATA +5 -3
  155. {flwr-1.19.0.dist-info → flwr-1.21.0.dist-info}/RECORD +161 -117
  156. {flwr-1.19.0.dist-info → flwr-1.21.0.dist-info}/entry_points.txt +1 -0
  157. flwr/common/inflatable_rest_utils.py +0 -99
  158. flwr/proto/exec_pb2.py +0 -62
  159. flwr/superexec/app.py +0 -45
  160. flwr/superexec/deployment.py +0 -192
  161. flwr/superexec/executor.py +0 -100
  162. flwr/superexec/simulation.py +0 -130
  163. /flwr/proto/{exec_pb2.pyi → control_pb2.pyi} +0 -0
  164. /flwr/{server/superlink → supercore}/ffs/__init__.py +0 -0
  165. /flwr/{server/superlink → supercore}/ffs/ffs.py +0 -0
  166. /flwr/{server/superlink → supercore}/ffs/ffs_factory.py +0 -0
  167. {flwr-1.19.0.dist-info → flwr-1.21.0.dist-info}/WHEEL +0 -0
flwr/cli/run/run.py CHANGED
@@ -30,7 +30,7 @@ from flwr.cli.config_utils import (
30
30
  process_loaded_project_config,
31
31
  validate_federation_in_project_config,
32
32
  )
33
- from flwr.cli.constant import FEDERATION_CONFIG_HELP_MESSAGE
33
+ from flwr.cli.constant import FEDERATION_CONFIG_HELP_MESSAGE, RUN_CONFIG_HELP_MESSAGE
34
34
  from flwr.common.config import (
35
35
  flatten_dict,
36
36
  get_metadata_from_config,
@@ -41,8 +41,8 @@ from flwr.common.constant import CliOutputFormat
41
41
  from flwr.common.logger import print_json_error, redirect_output, restore_output
42
42
  from flwr.common.serde import config_record_to_proto, fab_to_proto, user_config_to_proto
43
43
  from flwr.common.typing import Fab
44
- from flwr.proto.exec_pb2 import StartRunRequest # pylint: disable=E0611
45
- from flwr.proto.exec_pb2_grpc import ExecStub
44
+ from flwr.proto.control_pb2 import StartRunRequest # pylint: disable=E0611
45
+ from flwr.proto.control_pb2_grpc import ControlStub
46
46
 
47
47
  from ..log import start_stream
48
48
  from ..utils import flwr_cli_grpc_exc_handler, init_channel, try_obtain_cli_auth_plugin
@@ -65,11 +65,7 @@ def run(
65
65
  typer.Option(
66
66
  "--run-config",
67
67
  "-c",
68
- help="Override run configuration values in the format:\n\n"
69
- "`--run-config 'key1=value1 key2=value2' --run-config 'key3=value3'`\n\n"
70
- "Values can be of any type supported in TOML, such as bool, int, "
71
- "float, or string. Ensure that the keys (`key1`, `key2`, `key3` "
72
- "in this example) exist in `pyproject.toml` for proper overriding.",
68
+ help=RUN_CONFIG_HELP_MESSAGE,
73
69
  ),
74
70
  ] = None,
75
71
  federation_config_overrides: Annotated[
@@ -112,7 +108,7 @@ def run(
112
108
  )
113
109
 
114
110
  if "address" in federation_config:
115
- _run_with_exec_api(
111
+ _run_with_control_api(
116
112
  app,
117
113
  federation,
118
114
  federation_config,
@@ -121,7 +117,7 @@ def run(
121
117
  output_format,
122
118
  )
123
119
  else:
124
- _run_without_exec_api(
120
+ _run_without_control_api(
125
121
  app, federation_config, run_config_overrides, federation
126
122
  )
127
123
  except (typer.Exit, Exception) as err: # pylint: disable=broad-except
@@ -142,7 +138,7 @@ def run(
142
138
 
143
139
 
144
140
  # pylint: disable-next=R0913, R0914, R0917
145
- def _run_with_exec_api(
141
+ def _run_with_control_api(
146
142
  app: Path,
147
143
  federation: str,
148
144
  federation_config: dict[str, Any],
@@ -150,53 +146,60 @@ def _run_with_exec_api(
150
146
  stream: bool,
151
147
  output_format: str,
152
148
  ) -> None:
153
- auth_plugin = try_obtain_cli_auth_plugin(app, federation, federation_config)
154
- channel = init_channel(app, federation_config, auth_plugin)
155
- stub = ExecStub(channel)
149
+ channel = None
150
+ try:
151
+ auth_plugin = try_obtain_cli_auth_plugin(app, federation, federation_config)
152
+ channel = init_channel(app, federation_config, auth_plugin)
153
+ stub = ControlStub(channel)
156
154
 
157
- fab_bytes, fab_hash, config = build_fab(app)
158
- fab_id, fab_version = get_metadata_from_config(config)
155
+ fab_bytes, fab_hash, config = build_fab(app)
156
+ fab_id, fab_version = get_metadata_from_config(config)
159
157
 
160
- fab = Fab(fab_hash, fab_bytes)
158
+ fab = Fab(fab_hash, fab_bytes)
161
159
 
162
- # Construct a `ConfigRecord` out of a flattened `UserConfig`
163
- fed_config = flatten_dict(federation_config.get("options", {}))
164
- c_record = user_config_to_configrecord(fed_config)
160
+ # Construct a `ConfigRecord` out of a flattened `UserConfig`
161
+ fed_config = flatten_dict(federation_config.get("options", {}))
162
+ c_record = user_config_to_configrecord(fed_config)
165
163
 
166
- req = StartRunRequest(
167
- fab=fab_to_proto(fab),
168
- override_config=user_config_to_proto(parse_config_args(config_overrides)),
169
- federation_options=config_record_to_proto(c_record),
170
- )
171
- with flwr_cli_grpc_exc_handler():
172
- res = stub.StartRun(req)
173
-
174
- if res.HasField("run_id"):
175
- typer.secho(f"🎊 Successfully started run {res.run_id}", fg=typer.colors.GREEN)
176
- else:
177
- typer.secho("❌ Failed to start run", fg=typer.colors.RED)
178
- raise typer.Exit(code=1)
179
-
180
- if output_format == CliOutputFormat.JSON:
181
- run_output = json.dumps(
182
- {
183
- "success": res.HasField("run_id"),
184
- "run-id": res.run_id if res.HasField("run_id") else None,
185
- "fab-id": fab_id,
186
- "fab-name": fab_id.rsplit("/", maxsplit=1)[-1],
187
- "fab-version": fab_version,
188
- "fab-hash": fab_hash[:8],
189
- "fab-filename": get_fab_filename(config, fab_hash),
190
- }
164
+ req = StartRunRequest(
165
+ fab=fab_to_proto(fab),
166
+ override_config=user_config_to_proto(parse_config_args(config_overrides)),
167
+ federation_options=config_record_to_proto(c_record),
191
168
  )
192
- restore_output()
193
- Console().print_json(run_output)
169
+ with flwr_cli_grpc_exc_handler():
170
+ res = stub.StartRun(req)
194
171
 
195
- if stream:
196
- start_stream(res.run_id, channel, CONN_REFRESH_PERIOD)
172
+ if res.HasField("run_id"):
173
+ typer.secho(
174
+ f"🎊 Successfully started run {res.run_id}", fg=typer.colors.GREEN
175
+ )
176
+ else:
177
+ typer.secho("❌ Failed to start run", fg=typer.colors.RED)
178
+ raise typer.Exit(code=1)
179
+
180
+ if output_format == CliOutputFormat.JSON:
181
+ run_output = json.dumps(
182
+ {
183
+ "success": res.HasField("run_id"),
184
+ "run-id": res.run_id if res.HasField("run_id") else None,
185
+ "fab-id": fab_id,
186
+ "fab-name": fab_id.rsplit("/", maxsplit=1)[-1],
187
+ "fab-version": fab_version,
188
+ "fab-hash": fab_hash[:8],
189
+ "fab-filename": get_fab_filename(config, fab_hash),
190
+ }
191
+ )
192
+ restore_output()
193
+ Console().print_json(run_output)
194
+
195
+ if stream:
196
+ start_stream(res.run_id, channel, CONN_REFRESH_PERIOD)
197
+ finally:
198
+ if channel:
199
+ channel.close()
197
200
 
198
201
 
199
- def _run_without_exec_api(
202
+ def _run_without_control_api(
200
203
  app: Optional[Path],
201
204
  federation_config: dict[str, Any],
202
205
  config_overrides: Optional[list[str]],
flwr/cli/stop.py CHANGED
@@ -32,8 +32,11 @@ from flwr.cli.config_utils import (
32
32
  from flwr.cli.constant import FEDERATION_CONFIG_HELP_MESSAGE
33
33
  from flwr.common.constant import FAB_CONFIG_FILE, CliOutputFormat
34
34
  from flwr.common.logger import print_json_error, redirect_output, restore_output
35
- from flwr.proto.exec_pb2 import StopRunRequest, StopRunResponse # pylint: disable=E0611
36
- from flwr.proto.exec_pb2_grpc import ExecStub
35
+ from flwr.proto.control_pb2 import ( # pylint: disable=E0611
36
+ StopRunRequest,
37
+ StopRunResponse,
38
+ )
39
+ from flwr.proto.control_pb2_grpc import ControlStub
37
40
 
38
41
  from .utils import flwr_cli_grpc_exc_handler, init_channel, try_obtain_cli_auth_plugin
39
42
 
@@ -88,7 +91,7 @@ def stop( # pylint: disable=R0914
88
91
  try:
89
92
  auth_plugin = try_obtain_cli_auth_plugin(app, federation, federation_config)
90
93
  channel = init_channel(app, federation_config, auth_plugin)
91
- stub = ExecStub(channel) # pylint: disable=unused-variable # noqa: F841
94
+ stub = ControlStub(channel) # pylint: disable=unused-variable # noqa: F841
92
95
 
93
96
  typer.secho(f"✋ Stopping run ID {run_id}...", fg=typer.colors.GREEN)
94
97
  _stop_run(stub=stub, run_id=run_id, output_format=output_format)
@@ -120,7 +123,7 @@ def stop( # pylint: disable=R0914
120
123
  captured_output.close()
121
124
 
122
125
 
123
- def _stop_run(stub: ExecStub, run_id: int, output_format: str) -> None:
126
+ def _stop_run(stub: ControlStub, run_id: int, output_format: str) -> None:
124
127
  """Stop a run."""
125
128
  with flwr_cli_grpc_exc_handler():
126
129
  response: StopRunResponse = stub.StopRun(request=StopRunRequest(run_id=run_id))
flwr/cli/utils.py CHANGED
@@ -32,6 +32,7 @@ from flwr.common.constant import (
32
32
  AUTH_TYPE_JSON_KEY,
33
33
  CREDENTIALS_DIR,
34
34
  FLWR_DIR,
35
+ NO_USER_AUTH_MESSAGE,
35
36
  RUN_ID_NOT_FOUND_MESSAGE,
36
37
  )
37
38
  from flwr.common.grpc import (
@@ -259,7 +260,7 @@ def try_obtain_cli_auth_plugin(
259
260
  def init_channel(
260
261
  app: Path, federation_config: dict[str, Any], auth_plugin: Optional[CliAuthPlugin]
261
262
  ) -> grpc.Channel:
262
- """Initialize gRPC channel to the Exec API."""
263
+ """Initialize gRPC channel to the Control API."""
263
264
  insecure, root_certificates_bytes = validate_certificate_in_federation_config(
264
265
  app, federation_config
265
266
  )
@@ -296,9 +297,9 @@ def init_channel(
296
297
  def flwr_cli_grpc_exc_handler() -> Iterator[None]:
297
298
  """Context manager to handle specific gRPC errors.
298
299
 
299
- It catches grpc.RpcError exceptions with UNAUTHENTICATED, UNIMPLEMENTED, and
300
- PERMISSION_DENIED statuses, informs the user, and exits the application. All other
301
- exceptions will be allowed to escape.
300
+ It catches grpc.RpcError exceptions with UNAUTHENTICATED, UNIMPLEMENTED,
301
+ UNAVAILABLE, and PERMISSION_DENIED statuses, informs the user, and exits the
302
+ application. All other exceptions will be allowed to escape.
302
303
  """
303
304
  try:
304
305
  yield
@@ -312,25 +313,42 @@ def flwr_cli_grpc_exc_handler() -> Iterator[None]:
312
313
  )
313
314
  raise typer.Exit(code=1) from None
314
315
  if e.code() == grpc.StatusCode.UNIMPLEMENTED:
316
+ if e.details() == NO_USER_AUTH_MESSAGE: # pylint: disable=E1101
317
+ typer.secho(
318
+ "❌ User authentication is not enabled on this SuperLink.",
319
+ fg=typer.colors.RED,
320
+ bold=True,
321
+ )
322
+ else:
323
+ typer.secho(
324
+ "❌ The SuperLink cannot process this request. Please verify that "
325
+ "you set the address to its Control API endpoint correctly in your "
326
+ "`pyproject.toml`, and ensure that the Flower versions used by "
327
+ "the CLI and SuperLink are compatible.",
328
+ fg=typer.colors.RED,
329
+ bold=True,
330
+ )
331
+ raise typer.Exit(code=1) from None
332
+ if e.code() == grpc.StatusCode.PERMISSION_DENIED:
315
333
  typer.secho(
316
- "❌ User authentication is not enabled on this SuperLink.",
334
+ "❌ Permission denied.",
317
335
  fg=typer.colors.RED,
318
336
  bold=True,
319
337
  )
338
+ # pylint: disable-next=E1101
339
+ typer.secho(e.details(), fg=typer.colors.RED, bold=True)
320
340
  raise typer.Exit(code=1) from None
321
- if e.code() == grpc.StatusCode.PERMISSION_DENIED:
341
+ if e.code() == grpc.StatusCode.UNAVAILABLE:
322
342
  typer.secho(
323
- " Authorization failed. Please contact your administrator"
324
- " to check your permissions.",
343
+ "Connection to the SuperLink is unavailable. Please check your network "
344
+ "connection and 'address' in the federation configuration.",
325
345
  fg=typer.colors.RED,
326
346
  bold=True,
327
347
  )
328
- # pylint: disable=E1101
329
- typer.secho(e.details(), fg=typer.colors.RED, bold=True)
330
348
  raise typer.Exit(code=1) from None
331
349
  if (
332
350
  e.code() == grpc.StatusCode.NOT_FOUND
333
- and e.details() == RUN_ID_NOT_FOUND_MESSAGE
351
+ and e.details() == RUN_ID_NOT_FOUND_MESSAGE # pylint: disable=E1101
334
352
  ):
335
353
  typer.secho(
336
354
  "❌ Run ID not found.",
@@ -29,6 +29,7 @@ from flwr.common.logger import log
29
29
  from flwr.common.message import Message
30
30
  from flwr.common.retry_invoker import RetryInvoker
31
31
  from flwr.common.typing import Fab, Run
32
+ from flwr.proto.message_pb2 import ObjectTree # pylint: disable=E0611
32
33
 
33
34
 
34
35
  @contextmanager
@@ -43,12 +44,15 @@ def grpc_adapter( # pylint: disable=R0913,too-many-positional-arguments
43
44
  ] = None,
44
45
  ) -> Iterator[
45
46
  tuple[
46
- Callable[[], Optional[Message]],
47
- Callable[[Message], None],
47
+ Callable[[], Optional[tuple[Message, ObjectTree]]],
48
+ Callable[[Message, ObjectTree], set[str]],
48
49
  Callable[[], Optional[int]],
49
50
  Callable[[], None],
50
51
  Callable[[int], Run],
51
52
  Callable[[str, int], Fab],
53
+ Callable[[int, str], bytes],
54
+ Callable[[int, str, bytes], None],
55
+ Callable[[int, str], None],
52
56
  ]
53
57
  ]:
54
58
  """Primitives for request/response-based interaction with a server via GrpcAdapter.
@@ -77,12 +81,15 @@ def grpc_adapter( # pylint: disable=R0913,too-many-positional-arguments
77
81
 
78
82
  Returns
79
83
  -------
80
- receive : Callable
81
- send : Callable
84
+ receive : Callable[[], Optional[tuple[Message, ObjectTree]]]
85
+ send : Callable[[Message, ObjectTree], set[str]]
82
86
  create_node : Optional[Callable]
83
87
  delete_node : Optional[Callable]
84
88
  get_run : Optional[Callable]
85
89
  get_fab : Optional[Callable]
90
+ pull_object : Callable[[str], bytes]
91
+ push_object : Callable[[str, bytes], None]
92
+ confirm_message_received : Callable[[str], None]
86
93
  """
87
94
  if authentication_keys is not None:
88
95
  log(ERROR, "Client authentication is not supported for this transport type.")
@@ -17,33 +17,21 @@
17
17
 
18
18
  from collections.abc import Iterator, Sequence
19
19
  from contextlib import contextmanager
20
- from copy import copy
21
- from logging import DEBUG, ERROR
20
+ from logging import ERROR
22
21
  from pathlib import Path
23
22
  from typing import Callable, Optional, Union, cast
24
23
 
25
24
  import grpc
26
25
  from cryptography.hazmat.primitives.asymmetric import ec
27
26
 
28
- from flwr.app.metadata import Metadata
29
- from flwr.client.message_handler.message_handler import validate_out_message
30
27
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH
31
28
  from flwr.common.constant import HEARTBEAT_CALL_TIMEOUT, HEARTBEAT_DEFAULT_INTERVAL
32
29
  from flwr.common.grpc import create_channel, on_channel_state_change
33
30
  from flwr.common.heartbeat import HeartbeatSender
34
- from flwr.common.inflatable import (
35
- get_all_nested_objects,
36
- get_object_tree,
37
- no_object_id_recompute,
38
- )
39
- from flwr.common.inflatable_grpc_utils import (
40
- make_pull_object_fn_grpc,
41
- make_push_object_fn_grpc,
42
- )
43
- from flwr.common.inflatable_utils import (
44
- inflate_object_from_contents,
45
- pull_objects,
46
- push_objects,
31
+ from flwr.common.inflatable_protobuf_utils import (
32
+ make_confirm_message_received_fn_protobuf,
33
+ make_pull_object_fn_protobuf,
34
+ make_push_object_fn_protobuf,
47
35
  )
48
36
  from flwr.common.logger import log
49
37
  from flwr.common.message import Message, remove_content_from_message
@@ -51,8 +39,8 @@ from flwr.common.retry_invoker import RetryInvoker, _wrap_stub
51
39
  from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
52
40
  generate_key_pairs,
53
41
  )
54
- from flwr.common.serde import message_to_proto, run_from_proto
55
- from flwr.common.typing import Fab, Run, RunNotRunningException
42
+ from flwr.common.serde import message_from_proto, message_to_proto, run_from_proto
43
+ from flwr.common.typing import Fab, Run
56
44
  from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
57
45
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
58
46
  CreateNodeRequest,
@@ -67,9 +55,7 @@ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
67
55
  SendNodeHeartbeatRequest,
68
56
  SendNodeHeartbeatResponse,
69
57
  )
70
- from flwr.proto.message_pb2 import ( # pylint: disable=E0611
71
- ConfirmMessageReceivedRequest,
72
- )
58
+ from flwr.proto.message_pb2 import ObjectTree # pylint: disable=E0611
73
59
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
74
60
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
75
61
 
@@ -90,12 +76,15 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
90
76
  adapter_cls: Optional[Union[type[FleetStub], type[GrpcAdapter]]] = None,
91
77
  ) -> Iterator[
92
78
  tuple[
93
- Callable[[], Optional[Message]],
94
- Callable[[Message], None],
79
+ Callable[[], Optional[tuple[Message, ObjectTree]]],
80
+ Callable[[Message, ObjectTree], set[str]],
95
81
  Callable[[], Optional[int]],
96
82
  Callable[[], None],
97
83
  Callable[[int], Run],
98
84
  Callable[[str, int], Fab],
85
+ Callable[[int, str], bytes],
86
+ Callable[[int, str, bytes], None],
87
+ Callable[[int, str], None],
99
88
  ]
100
89
  ]:
101
90
  """Primitives for request/response-based interaction with a server.
@@ -138,6 +127,9 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
138
127
  create_node : Optional[Callable]
139
128
  delete_node : Optional[Callable]
140
129
  get_run : Optional[Callable]
130
+ pull_object : Callable[[str], bytes]
131
+ push_object : Callable[[str, bytes], None]
132
+ confirm_message_received : Callable[[str], None]
141
133
  """
142
134
  if isinstance(root_certificates, str):
143
135
  root_certificates = Path(root_certificates).read_bytes()
@@ -163,20 +155,8 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
163
155
  if adapter_cls is None:
164
156
  adapter_cls = FleetStub
165
157
  stub = adapter_cls(channel)
166
- metadata: Optional[Metadata] = None
167
158
  node: Optional[Node] = None
168
159
 
169
- def _should_giveup_fn(e: Exception) -> bool:
170
- if e.code() == grpc.StatusCode.PERMISSION_DENIED: # type: ignore
171
- raise RunNotRunningException
172
- if e.code() == grpc.StatusCode.UNAVAILABLE: # type: ignore
173
- return False
174
- return True
175
-
176
- # Restrict retries to cases where the status code is UNAVAILABLE
177
- # If the status code is PERMISSION_DENIED, additionally raise RunNotRunningException
178
- retry_invoker.should_giveup = _should_giveup_fn
179
-
180
160
  # Wrap stub
181
161
  _wrap_stub(stub, retry_invoker)
182
162
  ###########################################################################
@@ -249,117 +229,52 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
249
229
  # Cleanup
250
230
  node = None
251
231
 
252
- def receive() -> Optional[Message]:
253
- """Receive next message from server."""
232
+ def receive() -> Optional[tuple[Message, ObjectTree]]:
233
+ """Pull a message with its ObjectTree from SuperLink."""
254
234
  # Get Node
255
235
  if node is None:
256
236
  log(ERROR, "Node instance missing")
257
237
  return None
258
238
 
259
- # Request instructions (message) from server
239
+ # Try to pull a message with its object tree from SuperLink
260
240
  request = PullMessagesRequest(node=node)
261
241
  response: PullMessagesResponse = stub.PullMessages(request=request)
262
242
 
263
- # Get the current Messages
264
- message_proto = (
265
- None if len(response.messages_list) == 0 else response.messages_list[0]
266
- )
243
+ # If no messages are available, return None
244
+ if len(response.messages_list) == 0:
245
+ return None
267
246
 
268
- # Discard the current message if not valid
269
- if message_proto is not None and not (
270
- message_proto.metadata.dst_node_id == node.node_id
271
- ):
272
- message_proto = None
247
+ # Get the current Message and its object tree
248
+ message_proto = response.messages_list[0]
249
+ object_tree = response.message_object_trees[0]
273
250
 
274
251
  # Construct the Message
275
- in_message: Optional[Message] = None
276
-
277
- if message_proto:
278
- msg_id = message_proto.metadata.message_id
279
- run_id = message_proto.metadata.run_id
280
- all_object_contents = pull_objects(
281
- list(response.objects_to_pull[msg_id].object_ids) + [msg_id],
282
- pull_object_fn=make_pull_object_fn_grpc(
283
- pull_object_grpc=stub.PullObject,
284
- node=node,
285
- run_id=run_id,
286
- ),
287
- )
252
+ in_message = message_from_proto(message_proto)
288
253
 
289
- # Confirm that the message has been received
290
- stub.ConfirmMessageReceived(
291
- ConfirmMessageReceivedRequest(
292
- node=node, run_id=run_id, message_object_id=msg_id
293
- )
294
- )
295
-
296
- in_message = cast(
297
- Message, inflate_object_from_contents(msg_id, all_object_contents)
298
- )
299
- # The deflated message doesn't contain the message_id (its own object_id)
300
- # Inject
301
- in_message.metadata.__dict__["_message_id"] = msg_id
302
-
303
- # Remember `metadata` of the in message
304
- nonlocal metadata
305
- metadata = copy(in_message.metadata) if in_message else None
306
-
307
- # Return the message if available
308
- return in_message
254
+ # Return the Message and its object tree
255
+ return in_message, object_tree
309
256
 
310
- def send(message: Message) -> None:
311
- """Send message reply to server."""
257
+ def send(message: Message, object_tree: ObjectTree) -> set[str]:
258
+ """Send the message with its ObjectTree to SuperLink."""
312
259
  # Get Node
313
260
  if node is None:
314
261
  log(ERROR, "Node instance missing")
315
- return
316
-
317
- # Get the metadata of the incoming message
318
- nonlocal metadata
319
- if metadata is None:
320
- log(ERROR, "No current message")
321
- return
262
+ return set()
322
263
 
323
- # Set message_id
324
- message.metadata.__dict__["_message_id"] = message.object_id
325
- # Validate out message
326
- if not validate_out_message(message, metadata):
327
- log(ERROR, "Invalid out message")
328
- return
264
+ # Remove the content from the message if it has
265
+ if message.has_content():
266
+ message = remove_content_from_message(message)
329
267
 
330
- with no_object_id_recompute():
331
- # Get all nested objects
332
- all_objects = get_all_nested_objects(message)
333
- object_tree = get_object_tree(message)
334
-
335
- # Serialize Message
336
- message_proto = message_to_proto(
337
- message=remove_content_from_message(message)
338
- )
339
- request = PushMessagesRequest(
340
- node=node,
341
- messages_list=[message_proto],
342
- message_object_trees=[object_tree],
343
- )
344
- response: PushMessagesResponse = stub.PushMessages(request=request)
345
-
346
- if response.objects_to_push:
347
- objs_to_push = set(
348
- response.objects_to_push[message.object_id].object_ids
349
- )
350
- push_objects(
351
- all_objects,
352
- push_object_fn=make_push_object_fn_grpc(
353
- push_object_grpc=stub.PushObject,
354
- node=node,
355
- run_id=message.metadata.run_id,
356
- ),
357
- object_ids_to_push=objs_to_push,
358
- )
359
- log(DEBUG, "Pushed %s objects to servicer.", len(objs_to_push))
268
+ # Send the message with its ObjectTree to SuperLink
269
+ request = PushMessagesRequest(
270
+ node=node,
271
+ messages_list=[message_to_proto(message)],
272
+ message_object_trees=[object_tree],
273
+ )
274
+ response: PushMessagesResponse = stub.PushMessages(request=request)
360
275
 
361
- # Cleanup
362
- metadata = None
276
+ # Get and return the object IDs to push
277
+ return set(response.objects_to_push)
363
278
 
364
279
  def get_run(run_id: int) -> Run:
365
280
  # Call FleetAPI
@@ -376,9 +291,58 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
376
291
 
377
292
  return Fab(get_fab_response.fab.hash_str, get_fab_response.fab.content)
378
293
 
294
+ def pull_object(run_id: int, object_id: str) -> bytes:
295
+ """Pull the object from the SuperLink."""
296
+ # Check Node
297
+ if node is None:
298
+ raise RuntimeError("Node instance missing")
299
+
300
+ fn = make_pull_object_fn_protobuf(
301
+ pull_object_protobuf=stub.PullObject,
302
+ node=node,
303
+ run_id=run_id,
304
+ )
305
+ return fn(object_id)
306
+
307
+ def push_object(run_id: int, object_id: str, contents: bytes) -> None:
308
+ """Push the object to the SuperLink."""
309
+ # Check Node
310
+ if node is None:
311
+ raise RuntimeError("Node instance missing")
312
+
313
+ fn = make_push_object_fn_protobuf(
314
+ push_object_protobuf=stub.PushObject,
315
+ node=node,
316
+ run_id=run_id,
317
+ )
318
+ fn(object_id, contents)
319
+
320
+ def confirm_message_received(run_id: int, object_id: str) -> None:
321
+ """Confirm that the message has been received."""
322
+ # Check Node
323
+ if node is None:
324
+ raise RuntimeError("Node instance missing")
325
+
326
+ fn = make_confirm_message_received_fn_protobuf(
327
+ confirm_message_received_protobuf=stub.ConfirmMessageReceived,
328
+ node=node,
329
+ run_id=run_id,
330
+ )
331
+ fn(object_id)
332
+
379
333
  try:
380
334
  # Yield methods
381
- yield (receive, send, create_node, delete_node, get_run, get_fab)
335
+ yield (
336
+ receive,
337
+ send,
338
+ create_node,
339
+ delete_node,
340
+ get_run,
341
+ get_fab,
342
+ pull_object,
343
+ push_object,
344
+ confirm_message_received,
345
+ )
382
346
  except Exception as exc: # pylint: disable=broad-except
383
347
  log(ERROR, exc)
384
348
  # Cleanup