flwr-nightly 1.10.0.dev20240707__py3-none-any.whl → 1.11.0.dev20240724__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. More details are available on the registry's release page.

Files changed (99)
  1. flwr/cli/build.py +16 -2
  2. flwr/cli/config_utils.py +47 -27
  3. flwr/cli/install.py +17 -1
  4. flwr/cli/new/new.py +32 -21
  5. flwr/cli/new/templates/app/code/{client.hf.py.tpl → client.huggingface.py.tpl} +15 -5
  6. flwr/cli/new/templates/app/code/client.jax.py.tpl +2 -1
  7. flwr/cli/new/templates/app/code/client.mlx.py.tpl +36 -13
  8. flwr/cli/new/templates/app/code/client.numpy.py.tpl +2 -1
  9. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +16 -5
  10. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +6 -3
  11. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +25 -5
  12. flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +22 -19
  13. flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +5 -3
  14. flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +1 -1
  15. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -0
  16. flwr/cli/new/templates/app/code/server.jax.py.tpl +16 -8
  17. flwr/cli/new/templates/app/code/server.mlx.py.tpl +12 -7
  18. flwr/cli/new/templates/app/code/server.numpy.py.tpl +16 -8
  19. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +15 -13
  20. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +17 -10
  21. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +16 -13
  22. flwr/cli/new/templates/app/code/{task.hf.py.tpl → task.huggingface.py.tpl} +14 -2
  23. flwr/cli/new/templates/app/code/task.mlx.py.tpl +14 -2
  24. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +14 -3
  25. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +13 -1
  26. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +9 -12
  27. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +38 -0
  28. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +17 -11
  29. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +17 -12
  30. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +12 -12
  31. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +13 -12
  32. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +12 -12
  33. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +15 -12
  34. flwr/cli/run/run.py +133 -54
  35. flwr/client/app.py +56 -24
  36. flwr/client/client_app.py +28 -8
  37. flwr/client/grpc_adapter_client/connection.py +3 -2
  38. flwr/client/grpc_client/connection.py +3 -2
  39. flwr/client/grpc_rere_client/connection.py +17 -6
  40. flwr/client/message_handler/message_handler.py +1 -1
  41. flwr/client/node_state.py +59 -12
  42. flwr/client/node_state_tests.py +4 -3
  43. flwr/client/rest_client/connection.py +19 -8
  44. flwr/client/supernode/app.py +39 -39
  45. flwr/client/typing.py +2 -2
  46. flwr/common/config.py +92 -2
  47. flwr/common/constant.py +3 -0
  48. flwr/common/context.py +24 -9
  49. flwr/common/logger.py +25 -0
  50. flwr/common/object_ref.py +84 -21
  51. flwr/common/serde.py +45 -0
  52. flwr/common/telemetry.py +17 -0
  53. flwr/common/typing.py +5 -0
  54. flwr/proto/common_pb2.py +36 -0
  55. flwr/proto/common_pb2.pyi +121 -0
  56. flwr/proto/common_pb2_grpc.py +4 -0
  57. flwr/proto/common_pb2_grpc.pyi +4 -0
  58. flwr/proto/driver_pb2.py +24 -19
  59. flwr/proto/driver_pb2.pyi +21 -1
  60. flwr/proto/exec_pb2.py +20 -11
  61. flwr/proto/exec_pb2.pyi +41 -1
  62. flwr/proto/run_pb2.py +12 -7
  63. flwr/proto/run_pb2.pyi +22 -1
  64. flwr/proto/task_pb2.py +7 -8
  65. flwr/server/__init__.py +2 -0
  66. flwr/server/compat/legacy_context.py +5 -4
  67. flwr/server/driver/grpc_driver.py +82 -140
  68. flwr/server/run_serverapp.py +40 -18
  69. flwr/server/server_app.py +56 -10
  70. flwr/server/serverapp_components.py +52 -0
  71. flwr/server/superlink/driver/driver_servicer.py +18 -3
  72. flwr/server/superlink/fleet/message_handler/message_handler.py +13 -2
  73. flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
  74. flwr/server/superlink/fleet/vce/backend/backend.py +4 -4
  75. flwr/server/superlink/fleet/vce/backend/raybackend.py +10 -10
  76. flwr/server/superlink/fleet/vce/vce_api.py +149 -117
  77. flwr/server/superlink/state/in_memory_state.py +11 -3
  78. flwr/server/superlink/state/sqlite_state.py +23 -8
  79. flwr/server/superlink/state/state.py +7 -2
  80. flwr/server/typing.py +2 -0
  81. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -2
  82. flwr/simulation/__init__.py +1 -1
  83. flwr/simulation/app.py +4 -3
  84. flwr/simulation/ray_transport/ray_actor.py +15 -19
  85. flwr/simulation/ray_transport/ray_client_proxy.py +22 -9
  86. flwr/simulation/run_simulation.py +269 -70
  87. flwr/superexec/app.py +17 -11
  88. flwr/superexec/deployment.py +111 -35
  89. flwr/superexec/exec_grpc.py +5 -1
  90. flwr/superexec/exec_servicer.py +6 -1
  91. flwr/superexec/executor.py +21 -0
  92. flwr/superexec/simulation.py +181 -0
  93. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/METADATA +3 -2
  94. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/RECORD +97 -91
  95. flwr/cli/new/templates/app/code/server.hf.py.tpl +0 -17
  96. flwr/cli/new/templates/app/pyproject.hf.toml.tpl +0 -37
  97. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/LICENSE +0 -0
  98. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/WHEEL +0 -0
  99. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/entry_points.txt +0 -0
@@ -6,10 +6,7 @@ build-backend = "hatchling.build"
6
6
  name = "$package_name"
7
7
  version = "1.0.0"
8
8
  description = ""
9
- authors = [
10
- { name = "The Flower Authors", email = "hello@flower.ai" },
11
- ]
12
- license = { text = "Apache License (2.0)" }
9
+ license = "Apache-2.0"
13
10
  dependencies = [
14
11
  "flwr[simulation]>=1.9.0,<2.0",
15
12
  "flwr-datasets[vision]>=0.0.2,<1.0.0",
@@ -20,15 +17,19 @@ dependencies = [
20
17
  [tool.hatch.build.targets.wheel]
21
18
  packages = ["."]
22
19
 
23
- [flower]
20
+ [tool.flwr.app]
24
21
  publisher = "$username"
25
22
 
26
- [flower.components]
27
- serverapp = "$import_name.server:app"
28
- clientapp = "$import_name.client:app"
23
+ [tool.flwr.app.components]
24
+ serverapp = "$import_name.server_app:app"
25
+ clientapp = "$import_name.client_app:app"
26
+
27
+ [tool.flwr.app.config]
28
+ num-server-rounds = 3
29
+ local-epochs = 1
29
30
 
30
- [flower.engine]
31
- name = "simulation"
31
+ [tool.flwr.federations]
32
+ default = "local-simulation"
32
33
 
33
- [flower.engine.simulation.supernode]
34
- num = 2
34
+ [tool.flwr.federations.local-simulation]
35
+ options.num-supernodes = 10
@@ -6,10 +6,7 @@ build-backend = "hatchling.build"
6
6
  name = "$package_name"
7
7
  version = "1.0.0"
8
8
  description = ""
9
- authors = [
10
- { name = "The Flower Authors", email = "hello@flower.ai" },
11
- ]
12
- license = { text = "Apache License (2.0)" }
9
+ license = "Apache-2.0"
13
10
  dependencies = [
14
11
  "flwr[simulation]>=1.9.0,<2.0",
15
12
  "flwr-datasets[vision]>=0.0.2,<1.0.0",
@@ -19,15 +16,18 @@ dependencies = [
19
16
  [tool.hatch.build.targets.wheel]
20
17
  packages = ["."]
21
18
 
22
- [flower]
19
+ [tool.flwr.app]
23
20
  publisher = "$username"
24
21
 
25
- [flower.components]
26
- serverapp = "$import_name.server:app"
27
- clientapp = "$import_name.client:app"
22
+ [tool.flwr.app.components]
23
+ serverapp = "$import_name.server_app:app"
24
+ clientapp = "$import_name.client_app:app"
25
+
26
+ [tool.flwr.app.config]
27
+ num-server-rounds = 3
28
28
 
29
- [flower.engine]
30
- name = "simulation"
29
+ [tool.flwr.federations]
30
+ default = "local-simulation"
31
31
 
32
- [flower.engine.simulation.supernode]
33
- num = 2
32
+ [tool.flwr.federations.local-simulation]
33
+ options.num-supernodes = 10
@@ -6,10 +6,7 @@ build-backend = "hatchling.build"
6
6
  name = "$package_name"
7
7
  version = "1.0.0"
8
8
  description = ""
9
- authors = [
10
- { name = "The Flower Authors", email = "hello@flower.ai" },
11
- ]
12
- license = { text = "Apache License (2.0)" }
9
+ license = "Apache-2.0"
13
10
  dependencies = [
14
11
  "flwr[simulation]>=1.9.0,<2.0",
15
12
  "flwr-datasets[vision]>=0.0.2,<1.0.0",
@@ -19,15 +16,21 @@ dependencies = [
19
16
  [tool.hatch.build.targets.wheel]
20
17
  packages = ["."]
21
18
 
22
- [flower]
19
+ [tool.flwr.app]
23
20
  publisher = "$username"
24
21
 
25
- [flower.components]
26
- serverapp = "$import_name.server:app"
27
- clientapp = "$import_name.client:app"
22
+ [tool.flwr.app.components]
23
+ serverapp = "$import_name.server_app:app"
24
+ clientapp = "$import_name.client_app:app"
25
+
26
+ [tool.flwr.app.config]
27
+ num-server-rounds = 3
28
+ local-epochs = 1
29
+ batch-size = 32
30
+ verbose = false
28
31
 
29
- [flower.engine]
30
- name = "simulation"
32
+ [tool.flwr.federations]
33
+ default = "local-simulation"
31
34
 
32
- [flower.engine.simulation.supernode]
33
- num = 2
35
+ [tool.flwr.federations.local-simulation]
36
+ options.num-supernodes = 10
flwr/cli/run/run.py CHANGED
@@ -14,59 +14,52 @@
14
14
  # ==============================================================================
15
15
  """Flower command line interface `run` command."""
16
16
 
17
+ import subprocess
17
18
  import sys
18
- from enum import Enum
19
19
  from logging import DEBUG
20
20
  from pathlib import Path
21
- from typing import Optional
21
+ from typing import Any, Dict, List, Optional
22
22
 
23
23
  import typer
24
24
  from typing_extensions import Annotated
25
25
 
26
- from flwr.cli import config_utils
27
26
  from flwr.cli.build import build
28
- from flwr.common.constant import SUPEREXEC_DEFAULT_ADDRESS
27
+ from flwr.cli.config_utils import load_and_validate
28
+ from flwr.common.config import flatten_dict, parse_config_args
29
29
  from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel
30
30
  from flwr.common.logger import log
31
+ from flwr.common.serde import user_config_to_proto
31
32
  from flwr.proto.exec_pb2 import StartRunRequest # pylint: disable=E0611
32
33
  from flwr.proto.exec_pb2_grpc import ExecStub
33
- from flwr.simulation.run_simulation import _run_simulation
34
-
35
-
36
- class Engine(str, Enum):
37
- """Enum defining the engine to run on."""
38
-
39
- SIMULATION = "simulation"
40
34
 
41
35
 
42
36
  # pylint: disable-next=too-many-locals
43
37
  def run(
44
- engine: Annotated[
45
- Optional[Engine],
46
- typer.Option(
47
- case_sensitive=False,
48
- help="The engine to run FL with (currently only simulation is supported).",
49
- ),
38
+ app_dir: Annotated[
39
+ Path,
40
+ typer.Argument(help="Path of the Flower project to run."),
41
+ ] = Path("."),
42
+ federation: Annotated[
43
+ Optional[str],
44
+ typer.Argument(help="Name of the federation to run the app on."),
50
45
  ] = None,
51
- use_superexec: Annotated[
52
- bool,
46
+ config_overrides: Annotated[
47
+ Optional[List[str]],
53
48
  typer.Option(
54
- case_sensitive=False, help="Use this flag to use the new SuperExec API"
49
+ "--run-config",
50
+ "-c",
51
+ help="Override configuration key-value pairs, should be of the format:\n\n"
52
+ "`--run-config key1=value1,key2=value2 --run-config key3=value3`\n\n"
53
+ "Note that `key1`, `key2`, and `key3` in this example need to exist "
54
+ "inside the `pyproject.toml` in order to be properly overriden.",
55
55
  ),
56
- ] = False,
57
- directory: Annotated[
58
- Optional[Path],
59
- typer.Option(help="Path of the Flower project to run"),
60
56
  ] = None,
61
57
  ) -> None:
62
58
  """Run Flower project."""
63
- if use_superexec:
64
- _start_superexec_run(directory)
65
- return
66
-
67
59
  typer.secho("Loading project configuration... ", fg=typer.colors.BLUE)
68
60
 
69
- config, errors, warnings = config_utils.load_and_validate()
61
+ pyproject_path = app_dir / "pyproject.toml" if app_dir else None
62
+ config, errors, warnings = load_and_validate(path=pyproject_path)
70
63
 
71
64
  if config is None:
72
65
  typer.secho(
@@ -88,50 +81,136 @@ def run(
88
81
 
89
82
  typer.secho("Success", fg=typer.colors.GREEN)
90
83
 
91
- server_app_ref = config["flower"]["components"]["serverapp"]
92
- client_app_ref = config["flower"]["components"]["clientapp"]
84
+ federation = federation or config["tool"]["flwr"]["federations"].get("default")
93
85
 
94
- if engine is None:
95
- engine = config["flower"]["engine"]["name"]
96
-
97
- if engine == Engine.SIMULATION:
98
- num_supernodes = config["flower"]["engine"]["simulation"]["supernode"]["num"]
99
- backend_config = config["flower"]["engine"]["simulation"].get(
100
- "backend_config", None
101
- )
102
-
103
- typer.secho("Starting run... ", fg=typer.colors.BLUE)
104
- _run_simulation(
105
- server_app_attr=server_app_ref,
106
- client_app_attr=client_app_ref,
107
- num_supernodes=num_supernodes,
108
- backend_config=backend_config,
86
+ if federation is None:
87
+ typer.secho(
88
+ "❌ No federation name was provided and the project's `pyproject.toml` "
89
+ "doesn't declare a default federation (with a SuperExec address or an "
90
+ "`options.num-supernodes` value).",
91
+ fg=typer.colors.RED,
92
+ bold=True,
109
93
  )
110
- else:
94
+ raise typer.Exit(code=1)
95
+
96
+ # Validate the federation exists in the configuration
97
+ federation_config = config["tool"]["flwr"]["federations"].get(federation)
98
+ if federation_config is None:
99
+ available_feds = {
100
+ fed for fed in config["tool"]["flwr"]["federations"] if fed != "default"
101
+ }
111
102
  typer.secho(
112
- f"Engine '{engine}' is not yet supported in `flwr run`",
103
+ f" There is no `{federation}` federation declared in "
104
+ "`pyproject.toml`.\n The following federations were found:\n\n"
105
+ + "\n".join(available_feds),
113
106
  fg=typer.colors.RED,
114
107
  bold=True,
115
108
  )
109
+ raise typer.Exit(code=1)
110
+
111
+ if "address" in federation_config:
112
+ _run_with_superexec(federation_config, app_dir, config_overrides)
113
+ else:
114
+ _run_without_superexec(app_dir, federation_config, federation, config_overrides)
116
115
 
117
116
 
118
- def _start_superexec_run(directory: Optional[Path]) -> None:
117
+ def _run_with_superexec(
118
+ federation_config: Dict[str, Any],
119
+ app_dir: Optional[Path],
120
+ config_overrides: Optional[List[str]],
121
+ ) -> None:
122
+
119
123
  def on_channel_state_change(channel_connectivity: str) -> None:
120
124
  """Log channel connectivity."""
121
125
  log(DEBUG, channel_connectivity)
122
126
 
127
+ insecure_str = federation_config.get("insecure")
128
+ if root_certificates := federation_config.get("root-certificates"):
129
+ root_certificates_bytes = Path(root_certificates).read_bytes()
130
+ if insecure := bool(insecure_str):
131
+ typer.secho(
132
+ "❌ `root_certificates` were provided but the `insecure` parameter"
133
+ "is set to `True`.",
134
+ fg=typer.colors.RED,
135
+ bold=True,
136
+ )
137
+ raise typer.Exit(code=1)
138
+ else:
139
+ root_certificates_bytes = None
140
+ if insecure_str is None:
141
+ typer.secho(
142
+ "❌ To disable TLS, set `insecure = true` in `pyproject.toml`.",
143
+ fg=typer.colors.RED,
144
+ bold=True,
145
+ )
146
+ raise typer.Exit(code=1)
147
+ if not (insecure := bool(insecure_str)):
148
+ typer.secho(
149
+ "❌ No certificate were given yet `insecure` is set to `False`.",
150
+ fg=typer.colors.RED,
151
+ bold=True,
152
+ )
153
+ raise typer.Exit(code=1)
154
+
123
155
  channel = create_channel(
124
- server_address=SUPEREXEC_DEFAULT_ADDRESS,
125
- insecure=True,
126
- root_certificates=None,
156
+ server_address=federation_config["address"],
157
+ insecure=insecure,
158
+ root_certificates=root_certificates_bytes,
127
159
  max_message_length=GRPC_MAX_MESSAGE_LENGTH,
128
160
  interceptors=None,
129
161
  )
130
162
  channel.subscribe(on_channel_state_change)
131
163
  stub = ExecStub(channel)
132
164
 
133
- fab_path = build(directory)
165
+ fab_path = build(app_dir)
134
166
 
135
- req = StartRunRequest(fab_file=Path(fab_path).read_bytes())
167
+ req = StartRunRequest(
168
+ fab_file=Path(fab_path).read_bytes(),
169
+ override_config=user_config_to_proto(
170
+ parse_config_args(config_overrides, separator=",")
171
+ ),
172
+ federation_config=user_config_to_proto(
173
+ flatten_dict(federation_config.get("options"))
174
+ ),
175
+ )
136
176
  res = stub.StartRun(req)
137
177
  typer.secho(f"🎊 Successfully started run {res.run_id}", fg=typer.colors.GREEN)
178
+
179
+
180
+ def _run_without_superexec(
181
+ app_path: Optional[Path],
182
+ federation_config: Dict[str, Any],
183
+ federation: str,
184
+ config_overrides: Optional[List[str]],
185
+ ) -> None:
186
+ try:
187
+ num_supernodes = federation_config["options"]["num-supernodes"]
188
+ except KeyError as err:
189
+ typer.secho(
190
+ "❌ The project's `pyproject.toml` needs to declare the number of"
191
+ " SuperNodes in the simulation. To simulate 10 SuperNodes,"
192
+ " use the following notation:\n\n"
193
+ f"[tool.flwr.federations.{federation}]\n"
194
+ "options.num-supernodes = 10\n",
195
+ fg=typer.colors.RED,
196
+ bold=True,
197
+ )
198
+ raise typer.Exit(code=1) from err
199
+
200
+ command = [
201
+ "flower-simulation",
202
+ "--app",
203
+ f"{app_path}",
204
+ "--num-supernodes",
205
+ f"{num_supernodes}",
206
+ ]
207
+
208
+ if config_overrides:
209
+ command.extend(["--run-config", f"{','.join(config_overrides)}"])
210
+
211
+ # Run the simulation
212
+ subprocess.run(
213
+ command,
214
+ check=True,
215
+ text=True,
216
+ )
flwr/client/app.py CHANGED
@@ -18,7 +18,8 @@ import signal
18
18
  import sys
19
19
  import time
20
20
  from dataclasses import dataclass
21
- from logging import DEBUG, ERROR, INFO, WARN
21
+ from logging import ERROR, INFO, WARN
22
+ from pathlib import Path
22
23
  from typing import Callable, ContextManager, Dict, Optional, Tuple, Type, Union
23
24
 
24
25
  from cryptography.hazmat.primitives.asymmetric import ec
@@ -27,7 +28,7 @@ from grpc import RpcError
27
28
  from flwr.client.client import Client
28
29
  from flwr.client.client_app import ClientApp, LoadClientAppError
29
30
  from flwr.client.typing import ClientFnExt
30
- from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, Message, event
31
+ from flwr.common import GRPC_MAX_MESSAGE_LENGTH, Context, EventType, Message, event
31
32
  from flwr.common.address import parse_address
32
33
  from flwr.common.constant import (
33
34
  MISSING_EXTRA_REST,
@@ -41,6 +42,7 @@ from flwr.common.constant import (
41
42
  from flwr.common.logger import log, warn_deprecated_feature
42
43
  from flwr.common.message import Error
43
44
  from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential
45
+ from flwr.common.typing import Run, UserConfig
44
46
 
45
47
  from .grpc_adapter_client.connection import grpc_adapter
46
48
  from .grpc_client.connection import grpc_connection
@@ -136,8 +138,8 @@ def start_client(
136
138
 
137
139
  Starting an SSL-enabled gRPC client using system certificates:
138
140
 
139
- >>> def client_fn(node_id: int, partition_id: Optional[int]):
140
- >>> return FlowerClient()
141
+ >>> def client_fn(context: Context):
142
+ >>> return FlowerClient().to_client()
141
143
  >>>
142
144
  >>> start_client(
143
145
  >>> server_address=localhost:8080,
@@ -158,6 +160,7 @@ def start_client(
158
160
  event(EventType.START_CLIENT_ENTER)
159
161
  _start_client_internal(
160
162
  server_address=server_address,
163
+ node_config={},
161
164
  load_client_app_fn=None,
162
165
  client_fn=client_fn,
163
166
  client=client,
@@ -179,6 +182,7 @@ def start_client(
179
182
  def _start_client_internal(
180
183
  *,
181
184
  server_address: str,
185
+ node_config: UserConfig,
182
186
  load_client_app_fn: Optional[Callable[[str, str], ClientApp]] = None,
183
187
  client_fn: Optional[ClientFnExt] = None,
184
188
  client: Optional[Client] = None,
@@ -191,7 +195,7 @@ def _start_client_internal(
191
195
  ] = None,
192
196
  max_retries: Optional[int] = None,
193
197
  max_wait_time: Optional[float] = None,
194
- partition_id: Optional[int] = None,
198
+ flwr_path: Optional[Path] = None,
195
199
  ) -> None:
196
200
  """Start a Flower client node which connects to a Flower server.
197
201
 
@@ -201,6 +205,8 @@ def _start_client_internal(
201
205
  The IPv4 or IPv6 address of the server. If the Flower
202
206
  server runs on the same machine on port 8080, then `server_address`
203
207
  would be `"[::]:8080"`.
208
+ node_config: UserConfig
209
+ The configuration of the node.
204
210
  load_client_app_fn : Optional[Callable[[], ClientApp]] (default: None)
205
211
  A function that can be used to load a `ClientApp` instance.
206
212
  client_fn : Optional[ClientFnExt]
@@ -235,9 +241,8 @@ def _start_client_internal(
235
241
  The maximum duration before the client stops trying to
236
242
  connect to the server in case of connection error.
237
243
  If set to None, there is no limit to the total time.
238
- partitioni_id: Optional[int] (default: None)
239
- The data partition index associated with this node. Better suited for
240
- prototyping purposes.
244
+ flwr_path: Optional[Path] (default: None)
245
+ The fully resolved path containing installed Flower Apps.
241
246
  """
242
247
  if insecure is None:
243
248
  insecure = root_certificates is None
@@ -248,8 +253,7 @@ def _start_client_internal(
248
253
  if client_fn is None:
249
254
  # Wrap `Client` instance in `client_fn`
250
255
  def single_client_factory(
251
- node_id: int, # pylint: disable=unused-argument
252
- partition_id: Optional[int], # pylint: disable=unused-argument
256
+ context: Context, # pylint: disable=unused-argument
253
257
  ) -> Client:
254
258
  if client is None: # Added this to keep mypy happy
255
259
  raise ValueError(
@@ -290,7 +294,7 @@ def _start_client_internal(
290
294
  log(WARN, "Connection attempt failed, retrying...")
291
295
  else:
292
296
  log(
293
- DEBUG,
297
+ WARN,
294
298
  "Connection attempt failed, retrying in %.2f seconds",
295
299
  retry_state.actual_wait,
296
300
  )
@@ -314,9 +318,10 @@ def _start_client_internal(
314
318
  on_backoff=_on_backoff,
315
319
  )
316
320
 
317
- node_state = NodeState(partition_id=partition_id)
318
- # run_id -> (fab_id, fab_version)
319
- run_info: Dict[int, Tuple[str, str]] = {}
321
+ # NodeState gets initialized when the first connection is established
322
+ node_state: Optional[NodeState] = None
323
+
324
+ runs: Dict[int, Run] = {}
320
325
 
321
326
  while not app_state_tracker.interrupt:
322
327
  sleep_duration: int = 0
@@ -330,9 +335,31 @@ def _start_client_internal(
330
335
  ) as conn:
331
336
  receive, send, create_node, delete_node, get_run = conn
332
337
 
333
- # Register node
334
- if create_node is not None:
335
- create_node() # pylint: disable=not-callable
338
+ # Register node when connecting the first time
339
+ if node_state is None:
340
+ if create_node is None:
341
+ if transport not in ["grpc-bidi", None]:
342
+ raise NotImplementedError(
343
+ "All transports except `grpc-bidi` require "
344
+ "an implementation for `create_node()`.'"
345
+ )
346
+ # gRPC-bidi doesn't have the concept of node_id,
347
+ # so we set it to -1
348
+ node_state = NodeState(
349
+ node_id=-1,
350
+ node_config={},
351
+ )
352
+ else:
353
+ # Call create_node fn to register node
354
+ node_id: Optional[int] = ( # pylint: disable=assignment-from-none
355
+ create_node()
356
+ ) # pylint: disable=not-callable
357
+ if node_id is None:
358
+ raise ValueError("Node registration failed")
359
+ node_state = NodeState(
360
+ node_id=node_id,
361
+ node_config=node_config,
362
+ )
336
363
 
337
364
  app_state_tracker.register_signal_handler()
338
365
  while not app_state_tracker.interrupt:
@@ -366,15 +393,17 @@ def _start_client_internal(
366
393
 
367
394
  # Get run info
368
395
  run_id = message.metadata.run_id
369
- if run_id not in run_info:
396
+ if run_id not in runs:
370
397
  if get_run is not None:
371
- run_info[run_id] = get_run(run_id)
398
+ runs[run_id] = get_run(run_id)
372
399
  # If get_run is None, i.e., in grpc-bidi mode
373
400
  else:
374
- run_info[run_id] = ("", "")
401
+ runs[run_id] = Run(run_id, "", "", {})
375
402
 
376
403
  # Register context for this run
377
- node_state.register_context(run_id=run_id)
404
+ node_state.register_context(
405
+ run_id=run_id, run=runs[run_id], flwr_path=flwr_path
406
+ )
378
407
 
379
408
  # Retrieve context for this run
380
409
  context = node_state.retrieve_context(run_id=run_id)
@@ -388,7 +417,10 @@ def _start_client_internal(
388
417
  # Handle app loading and task message
389
418
  try:
390
419
  # Load ClientApp instance
391
- client_app: ClientApp = load_client_app_fn(*run_info[run_id])
420
+ run: Run = runs[run_id]
421
+ client_app: ClientApp = load_client_app_fn(
422
+ run.fab_id, run.fab_version
423
+ )
392
424
 
393
425
  # Execute ClientApp
394
426
  reply_message = client_app(message=message, context=context)
@@ -571,9 +603,9 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
571
603
  Tuple[
572
604
  Callable[[], Optional[Message]],
573
605
  Callable[[Message], None],
606
+ Optional[Callable[[], Optional[int]]],
574
607
  Optional[Callable[[], None]],
575
- Optional[Callable[[], None]],
576
- Optional[Callable[[int], Tuple[str, str]]],
608
+ Optional[Callable[[int], Run]],
577
609
  ]
578
610
  ],
579
611
  ],
flwr/client/client_app.py CHANGED
@@ -30,21 +30,41 @@ from flwr.common.logger import warn_deprecated_feature, warn_preview_feature
30
30
  from .typing import ClientAppCallable
31
31
 
32
32
 
33
+ def _alert_erroneous_client_fn() -> None:
34
+ raise ValueError(
35
+ "A `ClientApp` cannot make use of a `client_fn` that does "
36
+ "not have a signature in the form: `def client_fn(context: "
37
+ "Context)`. You can import the `Context` like this: "
38
+ "`from flwr.common import Context`"
39
+ )
40
+
41
+
33
42
  def _inspect_maybe_adapt_client_fn_signature(client_fn: ClientFnExt) -> ClientFnExt:
34
43
  client_fn_args = inspect.signature(client_fn).parameters
44
+ first_arg = list(client_fn_args.keys())[0]
45
+
46
+ if len(client_fn_args) != 1:
47
+ _alert_erroneous_client_fn()
48
+
49
+ first_arg_type = client_fn_args[first_arg].annotation
35
50
 
36
- if not all(key in client_fn_args for key in ["node_id", "partition_id"]):
51
+ if first_arg_type is str or first_arg == "cid":
52
+ # Warn previous signature for `client_fn` seems to be used
37
53
  warn_deprecated_feature(
38
- "`client_fn` now expects a signature `def client_fn(node_id: int, "
39
- "partition_id: Optional[int])`.\nYou provided `client_fn` with signature: "
40
- f"{dict(client_fn_args.items())}"
54
+ "`client_fn` now expects a signature `def client_fn(context: Context)`."
55
+ "The provided `client_fn` has signature: "
56
+ f"{dict(client_fn_args.items())}. You can import the `Context` like this:"
57
+ " `from flwr.common import Context`"
41
58
  )
42
59
 
43
60
  # Wrap depcreated client_fn inside a function with the expected signature
44
61
  def adaptor_fn(
45
- node_id: int, partition_id: Optional[int] # pylint: disable=unused-argument
46
- ) -> Client:
47
- return client_fn(str(partition_id)) # type: ignore
62
+ context: Context,
63
+ ) -> Client: # pylint: disable=unused-argument
64
+ # if patition-id is defined, pass it. Else pass node_id that should
65
+ # always be defined during Context init.
66
+ cid = context.node_config.get("partition-id", context.node_id)
67
+ return client_fn(str(cid)) # type: ignore
48
68
 
49
69
  return adaptor_fn
50
70
 
@@ -71,7 +91,7 @@ class ClientApp:
71
91
  >>> class FlowerClient(NumPyClient):
72
92
  >>> # ...
73
93
  >>>
74
- >>> def client_fn(node_id: int, partition_id: Optional[int]):
94
+ >>> def client_fn(context: Context):
75
95
  >>> return FlowerClient().to_client()
76
96
  >>>
77
97
  >>> app = ClientApp(client_fn)
@@ -27,6 +27,7 @@ from flwr.common import GRPC_MAX_MESSAGE_LENGTH
27
27
  from flwr.common.logger import log
28
28
  from flwr.common.message import Message
29
29
  from flwr.common.retry_invoker import RetryInvoker
30
+ from flwr.common.typing import Run
30
31
 
31
32
 
32
33
  @contextmanager
@@ -43,9 +44,9 @@ def grpc_adapter( # pylint: disable=R0913
43
44
  Tuple[
44
45
  Callable[[], Optional[Message]],
45
46
  Callable[[Message], None],
47
+ Optional[Callable[[], Optional[int]]],
46
48
  Optional[Callable[[], None]],
47
- Optional[Callable[[], None]],
48
- Optional[Callable[[int], Tuple[str, str]]],
49
+ Optional[Callable[[int], Run]],
49
50
  ]
50
51
  ]:
51
52
  """Primitives for request/response-based interaction with a server via GrpcAdapter.
@@ -38,6 +38,7 @@ from flwr.common.constant import MessageType, MessageTypeLegacy
38
38
  from flwr.common.grpc import create_channel
39
39
  from flwr.common.logger import log
40
40
  from flwr.common.retry_invoker import RetryInvoker
41
+ from flwr.common.typing import Run
41
42
  from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
42
43
  ClientMessage,
43
44
  Reason,
@@ -71,9 +72,9 @@ def grpc_connection( # pylint: disable=R0913, R0915
71
72
  Tuple[
72
73
  Callable[[], Optional[Message]],
73
74
  Callable[[Message], None],
75
+ Optional[Callable[[], Optional[int]]],
74
76
  Optional[Callable[[], None]],
75
- Optional[Callable[[], None]],
76
- Optional[Callable[[int], Tuple[str, str]]],
77
+ Optional[Callable[[int], Run]],
77
78
  ]
78
79
  ]:
79
80
  """Establish a gRPC connection to a gRPC server.