flwr-nightly 1.10.0.dev20240707__py3-none-any.whl → 1.10.0.dev20240722__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

Files changed (92)
  1. flwr/cli/build.py +16 -2
  2. flwr/cli/config_utils.py +36 -14
  3. flwr/cli/install.py +17 -1
  4. flwr/cli/new/new.py +31 -20
  5. flwr/cli/new/templates/app/code/client.hf.py.tpl +11 -3
  6. flwr/cli/new/templates/app/code/client.jax.py.tpl +2 -1
  7. flwr/cli/new/templates/app/code/client.mlx.py.tpl +15 -10
  8. flwr/cli/new/templates/app/code/client.numpy.py.tpl +2 -1
  9. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +12 -3
  10. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +6 -3
  11. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +13 -3
  12. flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +2 -2
  13. flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +1 -1
  14. flwr/cli/new/templates/app/code/server.hf.py.tpl +16 -11
  15. flwr/cli/new/templates/app/code/server.jax.py.tpl +15 -8
  16. flwr/cli/new/templates/app/code/server.mlx.py.tpl +11 -7
  17. flwr/cli/new/templates/app/code/server.numpy.py.tpl +15 -8
  18. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +15 -13
  19. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +16 -10
  20. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +16 -13
  21. flwr/cli/new/templates/app/code/task.hf.py.tpl +2 -2
  22. flwr/cli/new/templates/app/code/task.mlx.py.tpl +2 -2
  23. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -1
  24. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +9 -12
  25. flwr/cli/new/templates/app/pyproject.hf.toml.tpl +17 -16
  26. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +17 -11
  27. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +17 -12
  28. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +12 -12
  29. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +13 -12
  30. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +12 -12
  31. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +15 -12
  32. flwr/cli/run/run.py +128 -53
  33. flwr/client/app.py +56 -24
  34. flwr/client/client_app.py +28 -8
  35. flwr/client/grpc_adapter_client/connection.py +3 -2
  36. flwr/client/grpc_client/connection.py +3 -2
  37. flwr/client/grpc_rere_client/connection.py +17 -6
  38. flwr/client/message_handler/message_handler.py +1 -1
  39. flwr/client/node_state.py +59 -12
  40. flwr/client/node_state_tests.py +4 -3
  41. flwr/client/rest_client/connection.py +19 -8
  42. flwr/client/supernode/app.py +55 -24
  43. flwr/client/typing.py +2 -2
  44. flwr/common/config.py +87 -2
  45. flwr/common/constant.py +3 -0
  46. flwr/common/context.py +24 -9
  47. flwr/common/logger.py +25 -0
  48. flwr/common/serde.py +45 -0
  49. flwr/common/telemetry.py +17 -0
  50. flwr/common/typing.py +5 -0
  51. flwr/proto/common_pb2.py +36 -0
  52. flwr/proto/common_pb2.pyi +121 -0
  53. flwr/proto/common_pb2_grpc.py +4 -0
  54. flwr/proto/common_pb2_grpc.pyi +4 -0
  55. flwr/proto/driver_pb2.py +24 -19
  56. flwr/proto/driver_pb2.pyi +21 -1
  57. flwr/proto/exec_pb2.py +16 -11
  58. flwr/proto/exec_pb2.pyi +22 -1
  59. flwr/proto/run_pb2.py +12 -7
  60. flwr/proto/run_pb2.pyi +22 -1
  61. flwr/proto/task_pb2.py +7 -8
  62. flwr/server/__init__.py +2 -0
  63. flwr/server/compat/legacy_context.py +5 -4
  64. flwr/server/driver/grpc_driver.py +82 -140
  65. flwr/server/run_serverapp.py +40 -15
  66. flwr/server/server_app.py +56 -10
  67. flwr/server/serverapp_components.py +52 -0
  68. flwr/server/superlink/driver/driver_servicer.py +18 -3
  69. flwr/server/superlink/fleet/message_handler/message_handler.py +13 -2
  70. flwr/server/superlink/fleet/vce/backend/backend.py +4 -4
  71. flwr/server/superlink/fleet/vce/backend/raybackend.py +10 -10
  72. flwr/server/superlink/fleet/vce/vce_api.py +149 -117
  73. flwr/server/superlink/state/in_memory_state.py +11 -3
  74. flwr/server/superlink/state/sqlite_state.py +23 -8
  75. flwr/server/superlink/state/state.py +7 -2
  76. flwr/server/typing.py +2 -0
  77. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -2
  78. flwr/simulation/app.py +4 -3
  79. flwr/simulation/ray_transport/ray_actor.py +15 -19
  80. flwr/simulation/ray_transport/ray_client_proxy.py +22 -9
  81. flwr/simulation/run_simulation.py +237 -66
  82. flwr/superexec/app.py +14 -7
  83. flwr/superexec/deployment.py +110 -33
  84. flwr/superexec/exec_grpc.py +5 -1
  85. flwr/superexec/exec_servicer.py +4 -1
  86. flwr/superexec/executor.py +18 -0
  87. flwr/superexec/simulation.py +151 -0
  88. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/METADATA +3 -2
  89. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/RECORD +92 -86
  90. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/LICENSE +0 -0
  91. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/WHEEL +0 -0
  92. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/entry_points.txt +0 -0
flwr/cli/run/run.py CHANGED
@@ -14,59 +14,49 @@
14
14
  # ==============================================================================
15
15
  """Flower command line interface `run` command."""
16
16
 
17
+ import subprocess
17
18
  import sys
18
- from enum import Enum
19
19
  from logging import DEBUG
20
20
  from pathlib import Path
21
- from typing import Optional
21
+ from typing import Any, Dict, List, Optional
22
22
 
23
23
  import typer
24
24
  from typing_extensions import Annotated
25
25
 
26
- from flwr.cli import config_utils
27
26
  from flwr.cli.build import build
28
- from flwr.common.constant import SUPEREXEC_DEFAULT_ADDRESS
27
+ from flwr.cli.config_utils import load_and_validate
28
+ from flwr.common.config import parse_config_args
29
29
  from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel
30
30
  from flwr.common.logger import log
31
+ from flwr.common.serde import user_config_to_proto
31
32
  from flwr.proto.exec_pb2 import StartRunRequest # pylint: disable=E0611
32
33
  from flwr.proto.exec_pb2_grpc import ExecStub
33
- from flwr.simulation.run_simulation import _run_simulation
34
-
35
-
36
- class Engine(str, Enum):
37
- """Enum defining the engine to run on."""
38
-
39
- SIMULATION = "simulation"
40
34
 
41
35
 
42
36
  # pylint: disable-next=too-many-locals
43
37
  def run(
44
- engine: Annotated[
45
- Optional[Engine],
46
- typer.Option(
47
- case_sensitive=False,
48
- help="The engine to run FL with (currently only simulation is supported).",
49
- ),
38
+ directory: Annotated[
39
+ Path,
40
+ typer.Argument(help="Path of the Flower project to run"),
41
+ ] = Path("."),
42
+ federation_name: Annotated[
43
+ Optional[str],
44
+ typer.Argument(help="Name of the federation to run the app on"),
50
45
  ] = None,
51
- use_superexec: Annotated[
52
- bool,
46
+ config_overrides: Annotated[
47
+ Optional[List[str]],
53
48
  typer.Option(
54
- case_sensitive=False, help="Use this flag to use the new SuperExec API"
49
+ "--run-config",
50
+ "-c",
51
+ help="Override configuration key-value pairs",
55
52
  ),
56
- ] = False,
57
- directory: Annotated[
58
- Optional[Path],
59
- typer.Option(help="Path of the Flower project to run"),
60
53
  ] = None,
61
54
  ) -> None:
62
55
  """Run Flower project."""
63
- if use_superexec:
64
- _start_superexec_run(directory)
65
- return
66
-
67
56
  typer.secho("Loading project configuration... ", fg=typer.colors.BLUE)
68
57
 
69
- config, errors, warnings = config_utils.load_and_validate()
58
+ pyproject_path = directory / "pyproject.toml" if directory else None
59
+ config, errors, warnings = load_and_validate(path=pyproject_path)
70
60
 
71
61
  if config is None:
72
62
  typer.secho(
@@ -88,42 +78,83 @@ def run(
88
78
 
89
79
  typer.secho("Success", fg=typer.colors.GREEN)
90
80
 
91
- server_app_ref = config["flower"]["components"]["serverapp"]
92
- client_app_ref = config["flower"]["components"]["clientapp"]
93
-
94
- if engine is None:
95
- engine = config["flower"]["engine"]["name"]
96
-
97
- if engine == Engine.SIMULATION:
98
- num_supernodes = config["flower"]["engine"]["simulation"]["supernode"]["num"]
99
- backend_config = config["flower"]["engine"]["simulation"].get(
100
- "backend_config", None
101
- )
81
+ federation_name = federation_name or config["tool"]["flwr"]["federations"].get(
82
+ "default"
83
+ )
102
84
 
103
- typer.secho("Starting run... ", fg=typer.colors.BLUE)
104
- _run_simulation(
105
- server_app_attr=server_app_ref,
106
- client_app_attr=client_app_ref,
107
- num_supernodes=num_supernodes,
108
- backend_config=backend_config,
85
+ if federation_name is None:
86
+ typer.secho(
87
+ "❌ No federation name was provided and the project's `pyproject.toml` "
88
+ "doesn't declare a default federation (with a SuperExec address or an "
89
+ "`options.num-supernodes` value).",
90
+ fg=typer.colors.RED,
91
+ bold=True,
109
92
  )
110
- else:
93
+ raise typer.Exit(code=1)
94
+
95
+ # Validate the federation exists in the configuration
96
+ federation = config["tool"]["flwr"]["federations"].get(federation_name)
97
+ if federation is None:
98
+ available_feds = {
99
+ fed for fed in config["tool"]["flwr"]["federations"] if fed != "default"
100
+ }
111
101
  typer.secho(
112
- f"Engine '{engine}' is not yet supported in `flwr run`",
102
+ f" There is no `{federation_name}` federation declared in the "
103
+ "`pyproject.toml`.\n The following federations were found:\n\n"
104
+ + "\n".join(available_feds),
113
105
  fg=typer.colors.RED,
114
106
  bold=True,
115
107
  )
108
+ raise typer.Exit(code=1)
116
109
 
110
+ if "address" in federation:
111
+ _run_with_superexec(federation, directory, config_overrides)
112
+ else:
113
+ _run_without_superexec(directory, federation, federation_name, config_overrides)
114
+
115
+
116
+ def _run_with_superexec(
117
+ federation: Dict[str, str],
118
+ directory: Optional[Path],
119
+ config_overrides: Optional[List[str]],
120
+ ) -> None:
117
121
 
118
- def _start_superexec_run(directory: Optional[Path]) -> None:
119
122
  def on_channel_state_change(channel_connectivity: str) -> None:
120
123
  """Log channel connectivity."""
121
124
  log(DEBUG, channel_connectivity)
122
125
 
126
+ insecure_str = federation.get("insecure")
127
+ if root_certificates := federation.get("root-certificates"):
128
+ root_certificates_bytes = Path(root_certificates).read_bytes()
129
+ if insecure := bool(insecure_str):
130
+ typer.secho(
131
+ "❌ `root_certificates` were provided but the `insecure` parameter"
132
+ "is set to `True`.",
133
+ fg=typer.colors.RED,
134
+ bold=True,
135
+ )
136
+ raise typer.Exit(code=1)
137
+ else:
138
+ root_certificates_bytes = None
139
+ if insecure_str is None:
140
+ typer.secho(
141
+ "❌ To disable TLS, set `insecure = true` in `pyproject.toml`.",
142
+ fg=typer.colors.RED,
143
+ bold=True,
144
+ )
145
+ raise typer.Exit(code=1)
146
+ if not (insecure := bool(insecure_str)):
147
+ typer.secho(
148
+ "❌ No certificate were given yet `insecure` is set to `False`.",
149
+ fg=typer.colors.RED,
150
+ bold=True,
151
+ )
152
+ raise typer.Exit(code=1)
153
+
123
154
  channel = create_channel(
124
- server_address=SUPEREXEC_DEFAULT_ADDRESS,
125
- insecure=True,
126
- root_certificates=None,
155
+ server_address=federation["address"],
156
+ insecure=insecure,
157
+ root_certificates=root_certificates_bytes,
127
158
  max_message_length=GRPC_MAX_MESSAGE_LENGTH,
128
159
  interceptors=None,
129
160
  )
@@ -132,6 +163,50 @@ def _start_superexec_run(directory: Optional[Path]) -> None:
132
163
 
133
164
  fab_path = build(directory)
134
165
 
135
- req = StartRunRequest(fab_file=Path(fab_path).read_bytes())
166
+ req = StartRunRequest(
167
+ fab_file=Path(fab_path).read_bytes(),
168
+ override_config=user_config_to_proto(
169
+ parse_config_args(config_overrides, separator=",")
170
+ ),
171
+ )
136
172
  res = stub.StartRun(req)
137
173
  typer.secho(f"🎊 Successfully started run {res.run_id}", fg=typer.colors.GREEN)
174
+
175
+
176
+ def _run_without_superexec(
177
+ app_path: Optional[Path],
178
+ federation: Dict[str, Any],
179
+ federation_name: str,
180
+ config_overrides: Optional[List[str]],
181
+ ) -> None:
182
+ try:
183
+ num_supernodes = federation["options"]["num-supernodes"]
184
+ except KeyError as err:
185
+ typer.secho(
186
+ "❌ The project's `pyproject.toml` needs to declare the number of"
187
+ " SuperNodes in the simulation. To simulate 10 SuperNodes,"
188
+ " use the following notation:\n\n"
189
+ f"[tool.flwr.federations.{federation_name}]\n"
190
+ "options.num-supernodes = 10\n",
191
+ fg=typer.colors.RED,
192
+ bold=True,
193
+ )
194
+ raise typer.Exit(code=1) from err
195
+
196
+ command = [
197
+ "flower-simulation",
198
+ "--app",
199
+ f"{app_path}",
200
+ "--num-supernodes",
201
+ f"{num_supernodes}",
202
+ ]
203
+
204
+ if config_overrides:
205
+ command.extend(["--run-config", f"{','.join(config_overrides)}"])
206
+
207
+ # Run the simulation
208
+ subprocess.run(
209
+ command,
210
+ check=True,
211
+ text=True,
212
+ )
flwr/client/app.py CHANGED
@@ -18,7 +18,8 @@ import signal
18
18
  import sys
19
19
  import time
20
20
  from dataclasses import dataclass
21
- from logging import DEBUG, ERROR, INFO, WARN
21
+ from logging import ERROR, INFO, WARN
22
+ from pathlib import Path
22
23
  from typing import Callable, ContextManager, Dict, Optional, Tuple, Type, Union
23
24
 
24
25
  from cryptography.hazmat.primitives.asymmetric import ec
@@ -27,7 +28,7 @@ from grpc import RpcError
27
28
  from flwr.client.client import Client
28
29
  from flwr.client.client_app import ClientApp, LoadClientAppError
29
30
  from flwr.client.typing import ClientFnExt
30
- from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, Message, event
31
+ from flwr.common import GRPC_MAX_MESSAGE_LENGTH, Context, EventType, Message, event
31
32
  from flwr.common.address import parse_address
32
33
  from flwr.common.constant import (
33
34
  MISSING_EXTRA_REST,
@@ -41,6 +42,7 @@ from flwr.common.constant import (
41
42
  from flwr.common.logger import log, warn_deprecated_feature
42
43
  from flwr.common.message import Error
43
44
  from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential
45
+ from flwr.common.typing import Run, UserConfig
44
46
 
45
47
  from .grpc_adapter_client.connection import grpc_adapter
46
48
  from .grpc_client.connection import grpc_connection
@@ -136,8 +138,8 @@ def start_client(
136
138
 
137
139
  Starting an SSL-enabled gRPC client using system certificates:
138
140
 
139
- >>> def client_fn(node_id: int, partition_id: Optional[int]):
140
- >>> return FlowerClient()
141
+ >>> def client_fn(context: Context):
142
+ >>> return FlowerClient().to_client()
141
143
  >>>
142
144
  >>> start_client(
143
145
  >>> server_address=localhost:8080,
@@ -158,6 +160,7 @@ def start_client(
158
160
  event(EventType.START_CLIENT_ENTER)
159
161
  _start_client_internal(
160
162
  server_address=server_address,
163
+ node_config={},
161
164
  load_client_app_fn=None,
162
165
  client_fn=client_fn,
163
166
  client=client,
@@ -179,6 +182,7 @@ def start_client(
179
182
  def _start_client_internal(
180
183
  *,
181
184
  server_address: str,
185
+ node_config: UserConfig,
182
186
  load_client_app_fn: Optional[Callable[[str, str], ClientApp]] = None,
183
187
  client_fn: Optional[ClientFnExt] = None,
184
188
  client: Optional[Client] = None,
@@ -191,7 +195,7 @@ def _start_client_internal(
191
195
  ] = None,
192
196
  max_retries: Optional[int] = None,
193
197
  max_wait_time: Optional[float] = None,
194
- partition_id: Optional[int] = None,
198
+ flwr_path: Optional[Path] = None,
195
199
  ) -> None:
196
200
  """Start a Flower client node which connects to a Flower server.
197
201
 
@@ -201,6 +205,8 @@ def _start_client_internal(
201
205
  The IPv4 or IPv6 address of the server. If the Flower
202
206
  server runs on the same machine on port 8080, then `server_address`
203
207
  would be `"[::]:8080"`.
208
+ node_config: UserConfig
209
+ The configuration of the node.
204
210
  load_client_app_fn : Optional[Callable[[], ClientApp]] (default: None)
205
211
  A function that can be used to load a `ClientApp` instance.
206
212
  client_fn : Optional[ClientFnExt]
@@ -235,9 +241,8 @@ def _start_client_internal(
235
241
  The maximum duration before the client stops trying to
236
242
  connect to the server in case of connection error.
237
243
  If set to None, there is no limit to the total time.
238
- partitioni_id: Optional[int] (default: None)
239
- The data partition index associated with this node. Better suited for
240
- prototyping purposes.
244
+ flwr_path: Optional[Path] (default: None)
245
+ The fully resolved path containing installed Flower Apps.
241
246
  """
242
247
  if insecure is None:
243
248
  insecure = root_certificates is None
@@ -248,8 +253,7 @@ def _start_client_internal(
248
253
  if client_fn is None:
249
254
  # Wrap `Client` instance in `client_fn`
250
255
  def single_client_factory(
251
- node_id: int, # pylint: disable=unused-argument
252
- partition_id: Optional[int], # pylint: disable=unused-argument
256
+ context: Context, # pylint: disable=unused-argument
253
257
  ) -> Client:
254
258
  if client is None: # Added this to keep mypy happy
255
259
  raise ValueError(
@@ -290,7 +294,7 @@ def _start_client_internal(
290
294
  log(WARN, "Connection attempt failed, retrying...")
291
295
  else:
292
296
  log(
293
- DEBUG,
297
+ WARN,
294
298
  "Connection attempt failed, retrying in %.2f seconds",
295
299
  retry_state.actual_wait,
296
300
  )
@@ -314,9 +318,10 @@ def _start_client_internal(
314
318
  on_backoff=_on_backoff,
315
319
  )
316
320
 
317
- node_state = NodeState(partition_id=partition_id)
318
- # run_id -> (fab_id, fab_version)
319
- run_info: Dict[int, Tuple[str, str]] = {}
321
+ # NodeState gets initialized when the first connection is established
322
+ node_state: Optional[NodeState] = None
323
+
324
+ runs: Dict[int, Run] = {}
320
325
 
321
326
  while not app_state_tracker.interrupt:
322
327
  sleep_duration: int = 0
@@ -330,9 +335,31 @@ def _start_client_internal(
330
335
  ) as conn:
331
336
  receive, send, create_node, delete_node, get_run = conn
332
337
 
333
- # Register node
334
- if create_node is not None:
335
- create_node() # pylint: disable=not-callable
338
+ # Register node when connecting the first time
339
+ if node_state is None:
340
+ if create_node is None:
341
+ if transport not in ["grpc-bidi", None]:
342
+ raise NotImplementedError(
343
+ "All transports except `grpc-bidi` require "
344
+ "an implementation for `create_node()`.'"
345
+ )
346
+ # gRPC-bidi doesn't have the concept of node_id,
347
+ # so we set it to -1
348
+ node_state = NodeState(
349
+ node_id=-1,
350
+ node_config={},
351
+ )
352
+ else:
353
+ # Call create_node fn to register node
354
+ node_id: Optional[int] = ( # pylint: disable=assignment-from-none
355
+ create_node()
356
+ ) # pylint: disable=not-callable
357
+ if node_id is None:
358
+ raise ValueError("Node registration failed")
359
+ node_state = NodeState(
360
+ node_id=node_id,
361
+ node_config=node_config,
362
+ )
336
363
 
337
364
  app_state_tracker.register_signal_handler()
338
365
  while not app_state_tracker.interrupt:
@@ -366,15 +393,17 @@ def _start_client_internal(
366
393
 
367
394
  # Get run info
368
395
  run_id = message.metadata.run_id
369
- if run_id not in run_info:
396
+ if run_id not in runs:
370
397
  if get_run is not None:
371
- run_info[run_id] = get_run(run_id)
398
+ runs[run_id] = get_run(run_id)
372
399
  # If get_run is None, i.e., in grpc-bidi mode
373
400
  else:
374
- run_info[run_id] = ("", "")
401
+ runs[run_id] = Run(run_id, "", "", {})
375
402
 
376
403
  # Register context for this run
377
- node_state.register_context(run_id=run_id)
404
+ node_state.register_context(
405
+ run_id=run_id, run=runs[run_id], flwr_path=flwr_path
406
+ )
378
407
 
379
408
  # Retrieve context for this run
380
409
  context = node_state.retrieve_context(run_id=run_id)
@@ -388,7 +417,10 @@ def _start_client_internal(
388
417
  # Handle app loading and task message
389
418
  try:
390
419
  # Load ClientApp instance
391
- client_app: ClientApp = load_client_app_fn(*run_info[run_id])
420
+ run: Run = runs[run_id]
421
+ client_app: ClientApp = load_client_app_fn(
422
+ run.fab_id, run.fab_version
423
+ )
392
424
 
393
425
  # Execute ClientApp
394
426
  reply_message = client_app(message=message, context=context)
@@ -571,9 +603,9 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
571
603
  Tuple[
572
604
  Callable[[], Optional[Message]],
573
605
  Callable[[Message], None],
606
+ Optional[Callable[[], Optional[int]]],
574
607
  Optional[Callable[[], None]],
575
- Optional[Callable[[], None]],
576
- Optional[Callable[[int], Tuple[str, str]]],
608
+ Optional[Callable[[int], Run]],
577
609
  ]
578
610
  ],
579
611
  ],
flwr/client/client_app.py CHANGED
@@ -30,21 +30,41 @@ from flwr.common.logger import warn_deprecated_feature, warn_preview_feature
30
30
  from .typing import ClientAppCallable
31
31
 
32
32
 
33
+ def _alert_erroneous_client_fn() -> None:
34
+ raise ValueError(
35
+ "A `ClientApp` cannot make use of a `client_fn` that does "
36
+ "not have a signature in the form: `def client_fn(context: "
37
+ "Context)`. You can import the `Context` like this: "
38
+ "`from flwr.common import Context`"
39
+ )
40
+
41
+
33
42
  def _inspect_maybe_adapt_client_fn_signature(client_fn: ClientFnExt) -> ClientFnExt:
34
43
  client_fn_args = inspect.signature(client_fn).parameters
44
+ first_arg = list(client_fn_args.keys())[0]
45
+
46
+ if len(client_fn_args) != 1:
47
+ _alert_erroneous_client_fn()
48
+
49
+ first_arg_type = client_fn_args[first_arg].annotation
35
50
 
36
- if not all(key in client_fn_args for key in ["node_id", "partition_id"]):
51
+ if first_arg_type is str or first_arg == "cid":
52
+ # Warn previous signature for `client_fn` seems to be used
37
53
  warn_deprecated_feature(
38
- "`client_fn` now expects a signature `def client_fn(node_id: int, "
39
- "partition_id: Optional[int])`.\nYou provided `client_fn` with signature: "
40
- f"{dict(client_fn_args.items())}"
54
+ "`client_fn` now expects a signature `def client_fn(context: Context)`."
55
+ "The provided `client_fn` has signature: "
56
+ f"{dict(client_fn_args.items())}. You can import the `Context` like this:"
57
+ " `from flwr.common import Context`"
41
58
  )
42
59
 
43
60
  # Wrap depcreated client_fn inside a function with the expected signature
44
61
  def adaptor_fn(
45
- node_id: int, partition_id: Optional[int] # pylint: disable=unused-argument
46
- ) -> Client:
47
- return client_fn(str(partition_id)) # type: ignore
62
+ context: Context,
63
+ ) -> Client: # pylint: disable=unused-argument
64
+ # if patition-id is defined, pass it. Else pass node_id that should
65
+ # always be defined during Context init.
66
+ cid = context.node_config.get("partition-id", context.node_id)
67
+ return client_fn(str(cid)) # type: ignore
48
68
 
49
69
  return adaptor_fn
50
70
 
@@ -71,7 +91,7 @@ class ClientApp:
71
91
  >>> class FlowerClient(NumPyClient):
72
92
  >>> # ...
73
93
  >>>
74
- >>> def client_fn(node_id: int, partition_id: Optional[int]):
94
+ >>> def client_fn(context: Context):
75
95
  >>> return FlowerClient().to_client()
76
96
  >>>
77
97
  >>> app = ClientApp(client_fn)
@@ -27,6 +27,7 @@ from flwr.common import GRPC_MAX_MESSAGE_LENGTH
27
27
  from flwr.common.logger import log
28
28
  from flwr.common.message import Message
29
29
  from flwr.common.retry_invoker import RetryInvoker
30
+ from flwr.common.typing import Run
30
31
 
31
32
 
32
33
  @contextmanager
@@ -43,9 +44,9 @@ def grpc_adapter( # pylint: disable=R0913
43
44
  Tuple[
44
45
  Callable[[], Optional[Message]],
45
46
  Callable[[Message], None],
47
+ Optional[Callable[[], Optional[int]]],
46
48
  Optional[Callable[[], None]],
47
- Optional[Callable[[], None]],
48
- Optional[Callable[[int], Tuple[str, str]]],
49
+ Optional[Callable[[int], Run]],
49
50
  ]
50
51
  ]:
51
52
  """Primitives for request/response-based interaction with a server via GrpcAdapter.
@@ -38,6 +38,7 @@ from flwr.common.constant import MessageType, MessageTypeLegacy
38
38
  from flwr.common.grpc import create_channel
39
39
  from flwr.common.logger import log
40
40
  from flwr.common.retry_invoker import RetryInvoker
41
+ from flwr.common.typing import Run
41
42
  from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
42
43
  ClientMessage,
43
44
  Reason,
@@ -71,9 +72,9 @@ def grpc_connection( # pylint: disable=R0913, R0915
71
72
  Tuple[
72
73
  Callable[[], Optional[Message]],
73
74
  Callable[[Message], None],
75
+ Optional[Callable[[], Optional[int]]],
74
76
  Optional[Callable[[], None]],
75
- Optional[Callable[[], None]],
76
- Optional[Callable[[int], Tuple[str, str]]],
77
+ Optional[Callable[[int], Run]],
77
78
  ]
78
79
  ]:
79
80
  """Establish a gRPC connection to a gRPC server.
@@ -40,7 +40,12 @@ from flwr.common.grpc import create_channel
40
40
  from flwr.common.logger import log
41
41
  from flwr.common.message import Message, Metadata
42
42
  from flwr.common.retry_invoker import RetryInvoker
43
- from flwr.common.serde import message_from_taskins, message_to_taskres
43
+ from flwr.common.serde import (
44
+ message_from_taskins,
45
+ message_to_taskres,
46
+ user_config_from_proto,
47
+ )
48
+ from flwr.common.typing import Run
44
49
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
45
50
  CreateNodeRequest,
46
51
  DeleteNodeRequest,
@@ -78,9 +83,9 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
78
83
  Tuple[
79
84
  Callable[[], Optional[Message]],
80
85
  Callable[[Message], None],
86
+ Optional[Callable[[], Optional[int]]],
81
87
  Optional[Callable[[], None]],
82
- Optional[Callable[[], None]],
83
- Optional[Callable[[int], Tuple[str, str]]],
88
+ Optional[Callable[[int], Run]],
84
89
  ]
85
90
  ]:
86
91
  """Primitives for request/response-based interaction with a server.
@@ -175,7 +180,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
175
180
  if not ping_stop_event.is_set():
176
181
  ping_stop_event.wait(next_interval)
177
182
 
178
- def create_node() -> None:
183
+ def create_node() -> Optional[int]:
179
184
  """Set create_node."""
180
185
  # Call FleetAPI
181
186
  create_node_request = CreateNodeRequest(ping_interval=PING_DEFAULT_INTERVAL)
@@ -188,6 +193,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
188
193
  nonlocal node, ping_thread
189
194
  node = cast(Node, create_node_response.node)
190
195
  ping_thread = start_ping_loop(ping, ping_stop_event)
196
+ return node.node_id
191
197
 
192
198
  def delete_node() -> None:
193
199
  """Set delete_node."""
@@ -266,7 +272,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
266
272
  # Cleanup
267
273
  metadata = None
268
274
 
269
- def get_run(run_id: int) -> Tuple[str, str]:
275
+ def get_run(run_id: int) -> Run:
270
276
  # Call FleetAPI
271
277
  get_run_request = GetRunRequest(run_id=run_id)
272
278
  get_run_response: GetRunResponse = retry_invoker.invoke(
@@ -275,7 +281,12 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
275
281
  )
276
282
 
277
283
  # Return fab_id and fab_version
278
- return get_run_response.run.fab_id, get_run_response.run.fab_version
284
+ return Run(
285
+ run_id,
286
+ get_run_response.run.fab_id,
287
+ get_run_response.run.fab_version,
288
+ user_config_from_proto(get_run_response.run.override_config),
289
+ )
279
290
 
280
291
  try:
281
292
  # Yield methods
@@ -92,7 +92,7 @@ def handle_legacy_message_from_msgtype(
92
92
  client_fn: ClientFnExt, message: Message, context: Context
93
93
  ) -> Message:
94
94
  """Handle legacy message in the inner most mod."""
95
- client = client_fn(message.metadata.dst_node_id, context.partition_id)
95
+ client = client_fn(context)
96
96
 
97
97
  # Check if NumPyClient is returend
98
98
  if isinstance(client, NumPyClient):