flwr-nightly 1.10.0.dev20240624__py3-none-any.whl → 1.10.0.dev20240722__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

Files changed (95)
  1. flwr/cli/build.py +18 -4
  2. flwr/cli/config_utils.py +36 -14
  3. flwr/cli/install.py +17 -1
  4. flwr/cli/new/new.py +31 -20
  5. flwr/cli/new/templates/app/code/client.hf.py.tpl +11 -3
  6. flwr/cli/new/templates/app/code/client.jax.py.tpl +2 -1
  7. flwr/cli/new/templates/app/code/client.mlx.py.tpl +15 -10
  8. flwr/cli/new/templates/app/code/client.numpy.py.tpl +2 -1
  9. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +12 -3
  10. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +6 -3
  11. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +13 -3
  12. flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +2 -2
  13. flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +1 -1
  14. flwr/cli/new/templates/app/code/server.hf.py.tpl +16 -11
  15. flwr/cli/new/templates/app/code/server.jax.py.tpl +15 -8
  16. flwr/cli/new/templates/app/code/server.mlx.py.tpl +11 -7
  17. flwr/cli/new/templates/app/code/server.numpy.py.tpl +15 -8
  18. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +15 -13
  19. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +16 -10
  20. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +16 -13
  21. flwr/cli/new/templates/app/code/task.hf.py.tpl +2 -2
  22. flwr/cli/new/templates/app/code/task.mlx.py.tpl +2 -2
  23. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -1
  24. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +9 -12
  25. flwr/cli/new/templates/app/pyproject.hf.toml.tpl +17 -16
  26. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +17 -11
  27. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +17 -12
  28. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +12 -12
  29. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +13 -12
  30. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +12 -12
  31. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +15 -12
  32. flwr/cli/run/run.py +135 -51
  33. flwr/client/__init__.py +2 -0
  34. flwr/client/app.py +63 -26
  35. flwr/client/client_app.py +49 -4
  36. flwr/client/grpc_adapter_client/connection.py +3 -2
  37. flwr/client/grpc_client/connection.py +3 -2
  38. flwr/client/grpc_rere_client/connection.py +17 -6
  39. flwr/client/message_handler/message_handler.py +3 -4
  40. flwr/client/node_state.py +60 -10
  41. flwr/client/node_state_tests.py +4 -3
  42. flwr/client/rest_client/connection.py +19 -8
  43. flwr/client/supernode/app.py +60 -21
  44. flwr/client/typing.py +1 -0
  45. flwr/common/config.py +87 -2
  46. flwr/common/constant.py +6 -0
  47. flwr/common/context.py +26 -1
  48. flwr/common/logger.py +38 -0
  49. flwr/common/message.py +0 -17
  50. flwr/common/serde.py +45 -0
  51. flwr/common/telemetry.py +17 -0
  52. flwr/common/typing.py +5 -0
  53. flwr/proto/common_pb2.py +36 -0
  54. flwr/proto/common_pb2.pyi +121 -0
  55. flwr/proto/common_pb2_grpc.py +4 -0
  56. flwr/proto/common_pb2_grpc.pyi +4 -0
  57. flwr/proto/driver_pb2.py +24 -19
  58. flwr/proto/driver_pb2.pyi +21 -1
  59. flwr/proto/exec_pb2.py +16 -11
  60. flwr/proto/exec_pb2.pyi +22 -1
  61. flwr/proto/run_pb2.py +12 -7
  62. flwr/proto/run_pb2.pyi +22 -1
  63. flwr/proto/task_pb2.py +7 -8
  64. flwr/server/__init__.py +2 -0
  65. flwr/server/compat/legacy_context.py +5 -4
  66. flwr/server/driver/grpc_driver.py +82 -140
  67. flwr/server/run_serverapp.py +40 -15
  68. flwr/server/server_app.py +56 -10
  69. flwr/server/serverapp_components.py +52 -0
  70. flwr/server/superlink/driver/driver_servicer.py +18 -3
  71. flwr/server/superlink/fleet/message_handler/message_handler.py +13 -2
  72. flwr/server/superlink/fleet/vce/backend/backend.py +4 -4
  73. flwr/server/superlink/fleet/vce/backend/raybackend.py +10 -10
  74. flwr/server/superlink/fleet/vce/vce_api.py +149 -122
  75. flwr/server/superlink/state/in_memory_state.py +15 -7
  76. flwr/server/superlink/state/sqlite_state.py +27 -12
  77. flwr/server/superlink/state/state.py +7 -2
  78. flwr/server/superlink/state/utils.py +6 -0
  79. flwr/server/typing.py +2 -0
  80. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -2
  81. flwr/simulation/app.py +52 -36
  82. flwr/simulation/ray_transport/ray_actor.py +15 -19
  83. flwr/simulation/ray_transport/ray_client_proxy.py +33 -13
  84. flwr/simulation/run_simulation.py +237 -66
  85. flwr/superexec/app.py +14 -7
  86. flwr/superexec/deployment.py +186 -0
  87. flwr/superexec/exec_grpc.py +5 -1
  88. flwr/superexec/exec_servicer.py +4 -1
  89. flwr/superexec/executor.py +18 -0
  90. flwr/superexec/simulation.py +151 -0
  91. {flwr_nightly-1.10.0.dev20240624.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/METADATA +3 -2
  92. {flwr_nightly-1.10.0.dev20240624.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/RECORD +95 -88
  93. {flwr_nightly-1.10.0.dev20240624.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/LICENSE +0 -0
  94. {flwr_nightly-1.10.0.dev20240624.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/WHEEL +0 -0
  95. {flwr_nightly-1.10.0.dev20240624.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/entry_points.txt +0 -0
flwr/cli/run/run.py CHANGED
@@ -14,53 +14,49 @@
14
14
  # ==============================================================================
15
15
  """Flower command line interface `run` command."""
16
16
 
17
+ import subprocess
17
18
  import sys
18
- from enum import Enum
19
19
  from logging import DEBUG
20
- from typing import Optional
20
+ from pathlib import Path
21
+ from typing import Any, Dict, List, Optional
21
22
 
22
23
  import typer
23
24
  from typing_extensions import Annotated
24
25
 
25
- from flwr.cli import config_utils
26
- from flwr.common.constant import SUPEREXEC_DEFAULT_ADDRESS
26
+ from flwr.cli.build import build
27
+ from flwr.cli.config_utils import load_and_validate
28
+ from flwr.common.config import parse_config_args
27
29
  from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel
28
30
  from flwr.common.logger import log
31
+ from flwr.common.serde import user_config_to_proto
29
32
  from flwr.proto.exec_pb2 import StartRunRequest # pylint: disable=E0611
30
33
  from flwr.proto.exec_pb2_grpc import ExecStub
31
- from flwr.simulation.run_simulation import _run_simulation
32
-
33
-
34
- class Engine(str, Enum):
35
- """Enum defining the engine to run on."""
36
-
37
- SIMULATION = "simulation"
38
34
 
39
35
 
40
36
  # pylint: disable-next=too-many-locals
41
37
  def run(
42
- engine: Annotated[
43
- Optional[Engine],
44
- typer.Option(
45
- case_sensitive=False,
46
- help="The engine to run FL with (currently only simulation is supported).",
47
- ),
38
+ directory: Annotated[
39
+ Path,
40
+ typer.Argument(help="Path of the Flower project to run"),
41
+ ] = Path("."),
42
+ federation_name: Annotated[
43
+ Optional[str],
44
+ typer.Argument(help="Name of the federation to run the app on"),
48
45
  ] = None,
49
- use_superexec: Annotated[
50
- bool,
46
+ config_overrides: Annotated[
47
+ Optional[List[str]],
51
48
  typer.Option(
52
- case_sensitive=False, help="Use this flag to use the new SuperExec API"
49
+ "--run-config",
50
+ "-c",
51
+ help="Override configuration key-value pairs",
53
52
  ),
54
- ] = False,
53
+ ] = None,
55
54
  ) -> None:
56
55
  """Run Flower project."""
57
- if use_superexec:
58
- _start_superexec_run()
59
- return
60
-
61
56
  typer.secho("Loading project configuration... ", fg=typer.colors.BLUE)
62
57
 
63
- config, errors, warnings = config_utils.load_and_validate()
58
+ pyproject_path = directory / "pyproject.toml" if directory else None
59
+ config, errors, warnings = load_and_validate(path=pyproject_path)
64
60
 
65
61
  if config is None:
66
62
  typer.secho(
@@ -82,47 +78,135 @@ def run(
82
78
 
83
79
  typer.secho("Success", fg=typer.colors.GREEN)
84
80
 
85
- server_app_ref = config["flower"]["components"]["serverapp"]
86
- client_app_ref = config["flower"]["components"]["clientapp"]
87
-
88
- if engine is None:
89
- engine = config["flower"]["engine"]["name"]
90
-
91
- if engine == Engine.SIMULATION:
92
- num_supernodes = config["flower"]["engine"]["simulation"]["supernode"]["num"]
93
- backend_config = config["flower"]["engine"]["simulation"].get(
94
- "backend_config", None
95
- )
81
+ federation_name = federation_name or config["tool"]["flwr"]["federations"].get(
82
+ "default"
83
+ )
96
84
 
97
- typer.secho("Starting run... ", fg=typer.colors.BLUE)
98
- _run_simulation(
99
- server_app_attr=server_app_ref,
100
- client_app_attr=client_app_ref,
101
- num_supernodes=num_supernodes,
102
- backend_config=backend_config,
85
+ if federation_name is None:
86
+ typer.secho(
87
+ "❌ No federation name was provided and the project's `pyproject.toml` "
88
+ "doesn't declare a default federation (with a SuperExec address or an "
89
+ "`options.num-supernodes` value).",
90
+ fg=typer.colors.RED,
91
+ bold=True,
103
92
  )
104
- else:
93
+ raise typer.Exit(code=1)
94
+
95
+ # Validate the federation exists in the configuration
96
+ federation = config["tool"]["flwr"]["federations"].get(federation_name)
97
+ if federation is None:
98
+ available_feds = {
99
+ fed for fed in config["tool"]["flwr"]["federations"] if fed != "default"
100
+ }
105
101
  typer.secho(
106
- f"Engine '{engine}' is not yet supported in `flwr run`",
102
+ f" There is no `{federation_name}` federation declared in the "
103
+ "`pyproject.toml`.\n The following federations were found:\n\n"
104
+ + "\n".join(available_feds),
107
105
  fg=typer.colors.RED,
108
106
  bold=True,
109
107
  )
108
+ raise typer.Exit(code=1)
109
+
110
+ if "address" in federation:
111
+ _run_with_superexec(federation, directory, config_overrides)
112
+ else:
113
+ _run_without_superexec(directory, federation, federation_name, config_overrides)
114
+
110
115
 
116
+ def _run_with_superexec(
117
+ federation: Dict[str, str],
118
+ directory: Optional[Path],
119
+ config_overrides: Optional[List[str]],
120
+ ) -> None:
111
121
 
112
- def _start_superexec_run() -> None:
113
122
  def on_channel_state_change(channel_connectivity: str) -> None:
114
123
  """Log channel connectivity."""
115
124
  log(DEBUG, channel_connectivity)
116
125
 
126
+ insecure_str = federation.get("insecure")
127
+ if root_certificates := federation.get("root-certificates"):
128
+ root_certificates_bytes = Path(root_certificates).read_bytes()
129
+ if insecure := bool(insecure_str):
130
+ typer.secho(
131
+ "❌ `root_certificates` were provided but the `insecure` parameter"
132
+ "is set to `True`.",
133
+ fg=typer.colors.RED,
134
+ bold=True,
135
+ )
136
+ raise typer.Exit(code=1)
137
+ else:
138
+ root_certificates_bytes = None
139
+ if insecure_str is None:
140
+ typer.secho(
141
+ "❌ To disable TLS, set `insecure = true` in `pyproject.toml`.",
142
+ fg=typer.colors.RED,
143
+ bold=True,
144
+ )
145
+ raise typer.Exit(code=1)
146
+ if not (insecure := bool(insecure_str)):
147
+ typer.secho(
148
+ "❌ No certificate were given yet `insecure` is set to `False`.",
149
+ fg=typer.colors.RED,
150
+ bold=True,
151
+ )
152
+ raise typer.Exit(code=1)
153
+
117
154
  channel = create_channel(
118
- server_address=SUPEREXEC_DEFAULT_ADDRESS,
119
- insecure=True,
120
- root_certificates=None,
155
+ server_address=federation["address"],
156
+ insecure=insecure,
157
+ root_certificates=root_certificates_bytes,
121
158
  max_message_length=GRPC_MAX_MESSAGE_LENGTH,
122
159
  interceptors=None,
123
160
  )
124
161
  channel.subscribe(on_channel_state_change)
125
162
  stub = ExecStub(channel)
126
163
 
127
- req = StartRunRequest()
128
- stub.StartRun(req)
164
+ fab_path = build(directory)
165
+
166
+ req = StartRunRequest(
167
+ fab_file=Path(fab_path).read_bytes(),
168
+ override_config=user_config_to_proto(
169
+ parse_config_args(config_overrides, separator=",")
170
+ ),
171
+ )
172
+ res = stub.StartRun(req)
173
+ typer.secho(f"🎊 Successfully started run {res.run_id}", fg=typer.colors.GREEN)
174
+
175
+
176
+ def _run_without_superexec(
177
+ app_path: Optional[Path],
178
+ federation: Dict[str, Any],
179
+ federation_name: str,
180
+ config_overrides: Optional[List[str]],
181
+ ) -> None:
182
+ try:
183
+ num_supernodes = federation["options"]["num-supernodes"]
184
+ except KeyError as err:
185
+ typer.secho(
186
+ "❌ The project's `pyproject.toml` needs to declare the number of"
187
+ " SuperNodes in the simulation. To simulate 10 SuperNodes,"
188
+ " use the following notation:\n\n"
189
+ f"[tool.flwr.federations.{federation_name}]\n"
190
+ "options.num-supernodes = 10\n",
191
+ fg=typer.colors.RED,
192
+ bold=True,
193
+ )
194
+ raise typer.Exit(code=1) from err
195
+
196
+ command = [
197
+ "flower-simulation",
198
+ "--app",
199
+ f"{app_path}",
200
+ "--num-supernodes",
201
+ f"{num_supernodes}",
202
+ ]
203
+
204
+ if config_overrides:
205
+ command.extend(["--run-config", f"{','.join(config_overrides)}"])
206
+
207
+ # Run the simulation
208
+ subprocess.run(
209
+ command,
210
+ check=True,
211
+ text=True,
212
+ )
flwr/client/__init__.py CHANGED
@@ -23,11 +23,13 @@ from .numpy_client import NumPyClient as NumPyClient
23
23
  from .supernode import run_client_app as run_client_app
24
24
  from .supernode import run_supernode as run_supernode
25
25
  from .typing import ClientFn as ClientFn
26
+ from .typing import ClientFnExt as ClientFnExt
26
27
 
27
28
  __all__ = [
28
29
  "Client",
29
30
  "ClientApp",
30
31
  "ClientFn",
32
+ "ClientFnExt",
31
33
  "NumPyClient",
32
34
  "mod",
33
35
  "run_client_app",
flwr/client/app.py CHANGED
@@ -18,7 +18,8 @@ import signal
18
18
  import sys
19
19
  import time
20
20
  from dataclasses import dataclass
21
- from logging import DEBUG, ERROR, INFO, WARN
21
+ from logging import ERROR, INFO, WARN
22
+ from pathlib import Path
22
23
  from typing import Callable, ContextManager, Dict, Optional, Tuple, Type, Union
23
24
 
24
25
  from cryptography.hazmat.primitives.asymmetric import ec
@@ -26,8 +27,8 @@ from grpc import RpcError
26
27
 
27
28
  from flwr.client.client import Client
28
29
  from flwr.client.client_app import ClientApp, LoadClientAppError
29
- from flwr.client.typing import ClientFn
30
- from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, Message, event
30
+ from flwr.client.typing import ClientFnExt
31
+ from flwr.common import GRPC_MAX_MESSAGE_LENGTH, Context, EventType, Message, event
31
32
  from flwr.common.address import parse_address
32
33
  from flwr.common.constant import (
33
34
  MISSING_EXTRA_REST,
@@ -41,6 +42,7 @@ from flwr.common.constant import (
41
42
  from flwr.common.logger import log, warn_deprecated_feature
42
43
  from flwr.common.message import Error
43
44
  from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential
45
+ from flwr.common.typing import Run, UserConfig
44
46
 
45
47
  from .grpc_adapter_client.connection import grpc_adapter
46
48
  from .grpc_client.connection import grpc_connection
@@ -51,7 +53,7 @@ from .numpy_client import NumPyClient
51
53
 
52
54
 
53
55
  def _check_actionable_client(
54
- client: Optional[Client], client_fn: Optional[ClientFn]
56
+ client: Optional[Client], client_fn: Optional[ClientFnExt]
55
57
  ) -> None:
56
58
  if client_fn is None and client is None:
57
59
  raise ValueError(
@@ -72,7 +74,7 @@ def _check_actionable_client(
72
74
  def start_client(
73
75
  *,
74
76
  server_address: str,
75
- client_fn: Optional[ClientFn] = None,
77
+ client_fn: Optional[ClientFnExt] = None,
76
78
  client: Optional[Client] = None,
77
79
  grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
78
80
  root_certificates: Optional[Union[bytes, str]] = None,
@@ -92,7 +94,7 @@ def start_client(
92
94
  The IPv4 or IPv6 address of the server. If the Flower
93
95
  server runs on the same machine on port 8080, then `server_address`
94
96
  would be `"[::]:8080"`.
95
- client_fn : Optional[ClientFn]
97
+ client_fn : Optional[ClientFnExt]
96
98
  A callable that instantiates a Client. (default: None)
97
99
  client : Optional[flwr.client.Client]
98
100
  An implementation of the abstract base
@@ -136,8 +138,8 @@ def start_client(
136
138
 
137
139
  Starting an SSL-enabled gRPC client using system certificates:
138
140
 
139
- >>> def client_fn(cid: str):
140
- >>> return FlowerClient()
141
+ >>> def client_fn(context: Context):
142
+ >>> return FlowerClient().to_client()
141
143
  >>>
142
144
  >>> start_client(
143
145
  >>> server_address=localhost:8080,
@@ -158,6 +160,7 @@ def start_client(
158
160
  event(EventType.START_CLIENT_ENTER)
159
161
  _start_client_internal(
160
162
  server_address=server_address,
163
+ node_config={},
161
164
  load_client_app_fn=None,
162
165
  client_fn=client_fn,
163
166
  client=client,
@@ -179,8 +182,9 @@ def start_client(
179
182
  def _start_client_internal(
180
183
  *,
181
184
  server_address: str,
185
+ node_config: UserConfig,
182
186
  load_client_app_fn: Optional[Callable[[str, str], ClientApp]] = None,
183
- client_fn: Optional[ClientFn] = None,
187
+ client_fn: Optional[ClientFnExt] = None,
184
188
  client: Optional[Client] = None,
185
189
  grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
186
190
  root_certificates: Optional[Union[bytes, str]] = None,
@@ -191,6 +195,7 @@ def _start_client_internal(
191
195
  ] = None,
192
196
  max_retries: Optional[int] = None,
193
197
  max_wait_time: Optional[float] = None,
198
+ flwr_path: Optional[Path] = None,
194
199
  ) -> None:
195
200
  """Start a Flower client node which connects to a Flower server.
196
201
 
@@ -200,9 +205,11 @@ def _start_client_internal(
200
205
  The IPv4 or IPv6 address of the server. If the Flower
201
206
  server runs on the same machine on port 8080, then `server_address`
202
207
  would be `"[::]:8080"`.
208
+ node_config: UserConfig
209
+ The configuration of the node.
203
210
  load_client_app_fn : Optional[Callable[[], ClientApp]] (default: None)
204
211
  A function that can be used to load a `ClientApp` instance.
205
- client_fn : Optional[ClientFn]
212
+ client_fn : Optional[ClientFnExt]
206
213
  A callable that instantiates a Client. (default: None)
207
214
  client : Optional[flwr.client.Client]
208
215
  An implementation of the abstract base
@@ -234,6 +241,8 @@ def _start_client_internal(
234
241
  The maximum duration before the client stops trying to
235
242
  connect to the server in case of connection error.
236
243
  If set to None, there is no limit to the total time.
244
+ flwr_path: Optional[Path] (default: None)
245
+ The fully resolved path containing installed Flower Apps.
237
246
  """
238
247
  if insecure is None:
239
248
  insecure = root_certificates is None
@@ -244,7 +253,7 @@ def _start_client_internal(
244
253
  if client_fn is None:
245
254
  # Wrap `Client` instance in `client_fn`
246
255
  def single_client_factory(
247
- cid: str, # pylint: disable=unused-argument
256
+ context: Context, # pylint: disable=unused-argument
248
257
  ) -> Client:
249
258
  if client is None: # Added this to keep mypy happy
250
259
  raise ValueError(
@@ -285,7 +294,7 @@ def _start_client_internal(
285
294
  log(WARN, "Connection attempt failed, retrying...")
286
295
  else:
287
296
  log(
288
- DEBUG,
297
+ WARN,
289
298
  "Connection attempt failed, retrying in %.2f seconds",
290
299
  retry_state.actual_wait,
291
300
  )
@@ -293,7 +302,7 @@ def _start_client_internal(
293
302
  retry_invoker = RetryInvoker(
294
303
  wait_gen_factory=exponential,
295
304
  recoverable_exceptions=connection_error_type,
296
- max_tries=max_retries,
305
+ max_tries=max_retries + 1 if max_retries is not None else None,
297
306
  max_time=max_wait_time,
298
307
  on_giveup=lambda retry_state: (
299
308
  log(
@@ -309,9 +318,10 @@ def _start_client_internal(
309
318
  on_backoff=_on_backoff,
310
319
  )
311
320
 
312
- node_state = NodeState()
313
- # run_id -> (fab_id, fab_version)
314
- run_info: Dict[int, Tuple[str, str]] = {}
321
+ # NodeState gets initialized when the first connection is established
322
+ node_state: Optional[NodeState] = None
323
+
324
+ runs: Dict[int, Run] = {}
315
325
 
316
326
  while not app_state_tracker.interrupt:
317
327
  sleep_duration: int = 0
@@ -325,9 +335,31 @@ def _start_client_internal(
325
335
  ) as conn:
326
336
  receive, send, create_node, delete_node, get_run = conn
327
337
 
328
- # Register node
329
- if create_node is not None:
330
- create_node() # pylint: disable=not-callable
338
+ # Register node when connecting the first time
339
+ if node_state is None:
340
+ if create_node is None:
341
+ if transport not in ["grpc-bidi", None]:
342
+ raise NotImplementedError(
343
+ "All transports except `grpc-bidi` require "
344
+ "an implementation for `create_node()`.'"
345
+ )
346
+ # gRPC-bidi doesn't have the concept of node_id,
347
+ # so we set it to -1
348
+ node_state = NodeState(
349
+ node_id=-1,
350
+ node_config={},
351
+ )
352
+ else:
353
+ # Call create_node fn to register node
354
+ node_id: Optional[int] = ( # pylint: disable=assignment-from-none
355
+ create_node()
356
+ ) # pylint: disable=not-callable
357
+ if node_id is None:
358
+ raise ValueError("Node registration failed")
359
+ node_state = NodeState(
360
+ node_id=node_id,
361
+ node_config=node_config,
362
+ )
331
363
 
332
364
  app_state_tracker.register_signal_handler()
333
365
  while not app_state_tracker.interrupt:
@@ -361,15 +393,17 @@ def _start_client_internal(
361
393
 
362
394
  # Get run info
363
395
  run_id = message.metadata.run_id
364
- if run_id not in run_info:
396
+ if run_id not in runs:
365
397
  if get_run is not None:
366
- run_info[run_id] = get_run(run_id)
398
+ runs[run_id] = get_run(run_id)
367
399
  # If get_run is None, i.e., in grpc-bidi mode
368
400
  else:
369
- run_info[run_id] = ("", "")
401
+ runs[run_id] = Run(run_id, "", "", {})
370
402
 
371
403
  # Register context for this run
372
- node_state.register_context(run_id=run_id)
404
+ node_state.register_context(
405
+ run_id=run_id, run=runs[run_id], flwr_path=flwr_path
406
+ )
373
407
 
374
408
  # Retrieve context for this run
375
409
  context = node_state.retrieve_context(run_id=run_id)
@@ -383,7 +417,10 @@ def _start_client_internal(
383
417
  # Handle app loading and task message
384
418
  try:
385
419
  # Load ClientApp instance
386
- client_app: ClientApp = load_client_app_fn(*run_info[run_id])
420
+ run: Run = runs[run_id]
421
+ client_app: ClientApp = load_client_app_fn(
422
+ run.fab_id, run.fab_version
423
+ )
387
424
 
388
425
  # Execute ClientApp
389
426
  reply_message = client_app(message=message, context=context)
@@ -566,9 +603,9 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
566
603
  Tuple[
567
604
  Callable[[], Optional[Message]],
568
605
  Callable[[Message], None],
606
+ Optional[Callable[[], Optional[int]]],
569
607
  Optional[Callable[[], None]],
570
- Optional[Callable[[], None]],
571
- Optional[Callable[[int], Tuple[str, str]]],
608
+ Optional[Callable[[int], Run]],
572
609
  ]
573
610
  ],
574
611
  ],
flwr/client/client_app.py CHANGED
@@ -15,19 +15,62 @@
15
15
  """Flower ClientApp."""
16
16
 
17
17
 
18
+ import inspect
18
19
  from typing import Callable, List, Optional
19
20
 
21
+ from flwr.client.client import Client
20
22
  from flwr.client.message_handler.message_handler import (
21
23
  handle_legacy_message_from_msgtype,
22
24
  )
23
25
  from flwr.client.mod.utils import make_ffn
24
- from flwr.client.typing import ClientFn, Mod
26
+ from flwr.client.typing import ClientFnExt, Mod
25
27
  from flwr.common import Context, Message, MessageType
26
- from flwr.common.logger import warn_preview_feature
28
+ from flwr.common.logger import warn_deprecated_feature, warn_preview_feature
27
29
 
28
30
  from .typing import ClientAppCallable
29
31
 
30
32
 
33
+ def _alert_erroneous_client_fn() -> None:
34
+ raise ValueError(
35
+ "A `ClientApp` cannot make use of a `client_fn` that does "
36
+ "not have a signature in the form: `def client_fn(context: "
37
+ "Context)`. You can import the `Context` like this: "
38
+ "`from flwr.common import Context`"
39
+ )
40
+
41
+
42
+ def _inspect_maybe_adapt_client_fn_signature(client_fn: ClientFnExt) -> ClientFnExt:
43
+ client_fn_args = inspect.signature(client_fn).parameters
44
+ first_arg = list(client_fn_args.keys())[0]
45
+
46
+ if len(client_fn_args) != 1:
47
+ _alert_erroneous_client_fn()
48
+
49
+ first_arg_type = client_fn_args[first_arg].annotation
50
+
51
+ if first_arg_type is str or first_arg == "cid":
52
+ # Warn previous signature for `client_fn` seems to be used
53
+ warn_deprecated_feature(
54
+ "`client_fn` now expects a signature `def client_fn(context: Context)`."
55
+ "The provided `client_fn` has signature: "
56
+ f"{dict(client_fn_args.items())}. You can import the `Context` like this:"
57
+ " `from flwr.common import Context`"
58
+ )
59
+
60
+ # Wrap depcreated client_fn inside a function with the expected signature
61
+ def adaptor_fn(
62
+ context: Context,
63
+ ) -> Client: # pylint: disable=unused-argument
64
+ # if patition-id is defined, pass it. Else pass node_id that should
65
+ # always be defined during Context init.
66
+ cid = context.node_config.get("partition-id", context.node_id)
67
+ return client_fn(str(cid)) # type: ignore
68
+
69
+ return adaptor_fn
70
+
71
+ return client_fn
72
+
73
+
31
74
  class ClientAppException(Exception):
32
75
  """Exception raised when an exception is raised while executing a ClientApp."""
33
76
 
@@ -48,7 +91,7 @@ class ClientApp:
48
91
  >>> class FlowerClient(NumPyClient):
49
92
  >>> # ...
50
93
  >>>
51
- >>> def client_fn(cid):
94
+ >>> def client_fn(context: Context):
52
95
  >>> return FlowerClient().to_client()
53
96
  >>>
54
97
  >>> app = ClientApp(client_fn)
@@ -65,7 +108,7 @@ class ClientApp:
65
108
 
66
109
  def __init__(
67
110
  self,
68
- client_fn: Optional[ClientFn] = None, # Only for backward compatibility
111
+ client_fn: Optional[ClientFnExt] = None, # Only for backward compatibility
69
112
  mods: Optional[List[Mod]] = None,
70
113
  ) -> None:
71
114
  self._mods: List[Mod] = mods if mods is not None else []
@@ -74,6 +117,8 @@ class ClientApp:
74
117
  self._call: Optional[ClientAppCallable] = None
75
118
  if client_fn is not None:
76
119
 
120
+ client_fn = _inspect_maybe_adapt_client_fn_signature(client_fn)
121
+
77
122
  def ffn(
78
123
  message: Message,
79
124
  context: Context,
@@ -27,6 +27,7 @@ from flwr.common import GRPC_MAX_MESSAGE_LENGTH
27
27
  from flwr.common.logger import log
28
28
  from flwr.common.message import Message
29
29
  from flwr.common.retry_invoker import RetryInvoker
30
+ from flwr.common.typing import Run
30
31
 
31
32
 
32
33
  @contextmanager
@@ -43,9 +44,9 @@ def grpc_adapter( # pylint: disable=R0913
43
44
  Tuple[
44
45
  Callable[[], Optional[Message]],
45
46
  Callable[[Message], None],
47
+ Optional[Callable[[], Optional[int]]],
46
48
  Optional[Callable[[], None]],
47
- Optional[Callable[[], None]],
48
- Optional[Callable[[int], Tuple[str, str]]],
49
+ Optional[Callable[[int], Run]],
49
50
  ]
50
51
  ]:
51
52
  """Primitives for request/response-based interaction with a server via GrpcAdapter.
@@ -38,6 +38,7 @@ from flwr.common.constant import MessageType, MessageTypeLegacy
38
38
  from flwr.common.grpc import create_channel
39
39
  from flwr.common.logger import log
40
40
  from flwr.common.retry_invoker import RetryInvoker
41
+ from flwr.common.typing import Run
41
42
  from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
42
43
  ClientMessage,
43
44
  Reason,
@@ -71,9 +72,9 @@ def grpc_connection( # pylint: disable=R0913, R0915
71
72
  Tuple[
72
73
  Callable[[], Optional[Message]],
73
74
  Callable[[Message], None],
75
+ Optional[Callable[[], Optional[int]]],
74
76
  Optional[Callable[[], None]],
75
- Optional[Callable[[], None]],
76
- Optional[Callable[[int], Tuple[str, str]]],
77
+ Optional[Callable[[int], Run]],
77
78
  ]
78
79
  ]:
79
80
  """Establish a gRPC connection to a gRPC server.
@@ -40,7 +40,12 @@ from flwr.common.grpc import create_channel
40
40
  from flwr.common.logger import log
41
41
  from flwr.common.message import Message, Metadata
42
42
  from flwr.common.retry_invoker import RetryInvoker
43
- from flwr.common.serde import message_from_taskins, message_to_taskres
43
+ from flwr.common.serde import (
44
+ message_from_taskins,
45
+ message_to_taskres,
46
+ user_config_from_proto,
47
+ )
48
+ from flwr.common.typing import Run
44
49
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
45
50
  CreateNodeRequest,
46
51
  DeleteNodeRequest,
@@ -78,9 +83,9 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
78
83
  Tuple[
79
84
  Callable[[], Optional[Message]],
80
85
  Callable[[Message], None],
86
+ Optional[Callable[[], Optional[int]]],
81
87
  Optional[Callable[[], None]],
82
- Optional[Callable[[], None]],
83
- Optional[Callable[[int], Tuple[str, str]]],
88
+ Optional[Callable[[int], Run]],
84
89
  ]
85
90
  ]:
86
91
  """Primitives for request/response-based interaction with a server.
@@ -175,7 +180,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
175
180
  if not ping_stop_event.is_set():
176
181
  ping_stop_event.wait(next_interval)
177
182
 
178
- def create_node() -> None:
183
+ def create_node() -> Optional[int]:
179
184
  """Set create_node."""
180
185
  # Call FleetAPI
181
186
  create_node_request = CreateNodeRequest(ping_interval=PING_DEFAULT_INTERVAL)
@@ -188,6 +193,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
188
193
  nonlocal node, ping_thread
189
194
  node = cast(Node, create_node_response.node)
190
195
  ping_thread = start_ping_loop(ping, ping_stop_event)
196
+ return node.node_id
191
197
 
192
198
  def delete_node() -> None:
193
199
  """Set delete_node."""
@@ -266,7 +272,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
266
272
  # Cleanup
267
273
  metadata = None
268
274
 
269
- def get_run(run_id: int) -> Tuple[str, str]:
275
+ def get_run(run_id: int) -> Run:
270
276
  # Call FleetAPI
271
277
  get_run_request = GetRunRequest(run_id=run_id)
272
278
  get_run_response: GetRunResponse = retry_invoker.invoke(
@@ -275,7 +281,12 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
275
281
  )
276
282
 
277
283
  # Return fab_id and fab_version
278
- return get_run_response.run.fab_id, get_run_response.run.fab_version
284
+ return Run(
285
+ run_id,
286
+ get_run_response.run.fab_id,
287
+ get_run_response.run.fab_version,
288
+ user_config_from_proto(get_run_response.run.override_config),
289
+ )
279
290
 
280
291
  try:
281
292
  # Yield methods