flwr-nightly 1.10.0.dev20240619__py3-none-any.whl → 1.10.0.dev20240707__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic; see the package's release page on the registry for more details.

Files changed (109):
  1. flwr/cli/app.py +3 -0
  2. flwr/cli/build.py +5 -9
  3. flwr/cli/new/new.py +104 -28
  4. flwr/cli/new/templates/app/README.flowertune.md.tpl +56 -0
  5. flwr/cli/new/templates/app/code/flwr_tune/__init__.py +15 -0
  6. flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +86 -0
  7. flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +124 -0
  8. flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +34 -0
  9. flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +57 -0
  10. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +59 -0
  11. flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +48 -0
  12. flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +11 -0
  13. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +42 -0
  14. flwr/cli/run/run.py +21 -5
  15. flwr/client/__init__.py +2 -0
  16. flwr/client/app.py +15 -10
  17. flwr/client/client_app.py +30 -5
  18. flwr/client/dpfedavg_numpy_client.py +1 -1
  19. flwr/client/grpc_rere_client/__init__.py +1 -1
  20. flwr/client/grpc_rere_client/connection.py +1 -1
  21. flwr/client/message_handler/__init__.py +1 -1
  22. flwr/client/message_handler/message_handler.py +4 -5
  23. flwr/client/mod/__init__.py +1 -1
  24. flwr/client/mod/secure_aggregation/__init__.py +1 -1
  25. flwr/client/mod/utils.py +1 -1
  26. flwr/client/node_state.py +6 -3
  27. flwr/client/node_state_tests.py +1 -1
  28. flwr/client/rest_client/__init__.py +1 -1
  29. flwr/client/rest_client/connection.py +1 -1
  30. flwr/client/supernode/app.py +12 -4
  31. flwr/client/typing.py +2 -1
  32. flwr/common/address.py +1 -1
  33. flwr/common/config.py +8 -6
  34. flwr/common/constant.py +4 -1
  35. flwr/common/context.py +11 -1
  36. flwr/common/date.py +1 -1
  37. flwr/common/dp.py +1 -1
  38. flwr/common/grpc.py +1 -1
  39. flwr/common/logger.py +13 -0
  40. flwr/common/message.py +0 -17
  41. flwr/common/secure_aggregation/__init__.py +1 -1
  42. flwr/common/secure_aggregation/crypto/__init__.py +1 -1
  43. flwr/common/secure_aggregation/crypto/shamir.py +1 -1
  44. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -1
  45. flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
  46. flwr/common/secure_aggregation/quantization.py +1 -1
  47. flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
  48. flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
  49. flwr/common/version.py +14 -0
  50. flwr/server/compat/app.py +1 -1
  51. flwr/server/compat/app_utils.py +1 -1
  52. flwr/server/compat/driver_client_proxy.py +1 -1
  53. flwr/server/driver/driver.py +6 -0
  54. flwr/server/driver/grpc_driver.py +85 -63
  55. flwr/server/driver/inmemory_driver.py +28 -26
  56. flwr/server/run_serverapp.py +61 -18
  57. flwr/server/strategy/bulyan.py +1 -1
  58. flwr/server/strategy/dpfedavg_adaptive.py +1 -1
  59. flwr/server/strategy/dpfedavg_fixed.py +1 -1
  60. flwr/server/strategy/fedadagrad.py +1 -1
  61. flwr/server/strategy/fedadam.py +1 -1
  62. flwr/server/strategy/fedavg_android.py +1 -1
  63. flwr/server/strategy/fedavgm.py +1 -1
  64. flwr/server/strategy/fedmedian.py +1 -1
  65. flwr/server/strategy/fedopt.py +1 -1
  66. flwr/server/strategy/fedprox.py +1 -1
  67. flwr/server/strategy/fedxgb_bagging.py +1 -1
  68. flwr/server/strategy/fedxgb_cyclic.py +1 -1
  69. flwr/server/strategy/fedxgb_nn_avg.py +1 -1
  70. flwr/server/strategy/fedyogi.py +1 -1
  71. flwr/server/strategy/krum.py +1 -1
  72. flwr/server/strategy/qfedavg.py +1 -1
  73. flwr/server/superlink/driver/__init__.py +1 -1
  74. flwr/server/superlink/driver/driver_grpc.py +1 -1
  75. flwr/server/superlink/driver/driver_servicer.py +15 -3
  76. flwr/server/superlink/fleet/__init__.py +1 -1
  77. flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
  78. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
  79. flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
  80. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
  81. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +1 -1
  82. flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
  83. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +1 -1
  84. flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
  85. flwr/server/superlink/fleet/message_handler/message_handler.py +1 -1
  86. flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
  87. flwr/server/superlink/fleet/rest_rere/rest_api.py +1 -1
  88. flwr/server/superlink/fleet/vce/backend/raybackend.py +45 -26
  89. flwr/server/superlink/fleet/vce/vce_api.py +3 -8
  90. flwr/server/superlink/state/__init__.py +1 -1
  91. flwr/server/superlink/state/in_memory_state.py +5 -5
  92. flwr/server/superlink/state/sqlite_state.py +5 -5
  93. flwr/server/superlink/state/state.py +1 -1
  94. flwr/server/superlink/state/state_factory.py +11 -2
  95. flwr/server/superlink/state/utils.py +6 -0
  96. flwr/server/utils/__init__.py +1 -1
  97. flwr/server/utils/tensorboard.py +1 -1
  98. flwr/simulation/__init__.py +1 -1
  99. flwr/simulation/app.py +52 -37
  100. flwr/simulation/ray_transport/__init__.py +1 -1
  101. flwr/simulation/ray_transport/ray_actor.py +0 -6
  102. flwr/simulation/ray_transport/ray_client_proxy.py +17 -10
  103. flwr/simulation/run_simulation.py +47 -28
  104. flwr/superexec/deployment.py +109 -0
  105. {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240707.dist-info}/METADATA +2 -1
  106. {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240707.dist-info}/RECORD +109 -98
  107. {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240707.dist-info}/LICENSE +0 -0
  108. {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240707.dist-info}/WHEEL +0 -0
  109. {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240707.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,48 @@
1
+ """$project_name: A Flower / FlowerTune app."""
2
+
3
+ from $import_name.client import set_parameters
4
+ from $import_name.models import get_model
5
+
6
+
7
+ # Get function that will be executed by the strategy's evaluate() method
8
+ # Here we use it to save global model checkpoints
9
+ def get_evaluate_fn(model_cfg, save_every_round, total_round, save_path):
10
+ """Return an evaluation function for saving global model."""
11
+
12
+ def evaluate(server_round: int, parameters, config):
13
+ # Save model
14
+ if server_round != 0 and (
15
+ server_round == total_round or server_round % save_every_round == 0
16
+ ):
17
+ # Init model
18
+ model = get_model(model_cfg)
19
+ set_parameters(model, parameters)
20
+
21
+ model.save_pretrained(f"{save_path}/peft_{server_round}")
22
+
23
+ return 0.0, {}
24
+
25
+ return evaluate
26
+
27
+
28
+ def get_on_fit_config():
29
+ """
30
+ Return a function that will be used to construct the config
31
+ that the client's fit() method will receive.
32
+ """
33
+
34
+ def fit_config_fn(server_round: int):
35
+ fit_config = {"current_round": server_round}
36
+ return fit_config
37
+
38
+ return fit_config_fn
39
+
40
+
41
+ def fit_weighted_average(metrics):
42
+ """Aggregate (federated) evaluation metrics."""
43
+ # Multiply accuracy of each client by number of examples used
44
+ losses = [num_examples * m["train_loss"] for num_examples, m in metrics]
45
+ examples = [num_examples for num_examples, _ in metrics]
46
+
47
+ # Aggregate and return custom metric (weighted average)
48
+ return {"train_loss": sum(losses) / sum(examples)}
@@ -0,0 +1,11 @@
1
+ # Federated Instruction Tuning (static)
2
+ ---
3
+ dataset:
4
+ name: $dataset_name
5
+
6
+ # FL experimental settings
7
+ num_clients: $num_clients # total number of clients
8
+ num_rounds: 200
9
+ partitioner:
10
+ _target_: flwr_datasets.partitioner.IidPartitioner
11
+ num_partitions: $num_clients
@@ -0,0 +1,42 @@
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "$package_name"
7
+ version = "1.0.0"
8
+ description = ""
9
+ authors = [
10
+ { name = "The Flower Authors", email = "hello@flower.ai" },
11
+ ]
12
+ license = { text = "Apache License (2.0)" }
13
+ dependencies = [
14
+ "flwr[simulation]>=1.9.0,<2.0",
15
+ "flwr-datasets>=0.1.0,<1.0.0",
16
+ "hydra-core==1.3.2",
17
+ "trl==0.8.1",
18
+ "bitsandbytes==0.43.0",
19
+ "scipy==1.13.0",
20
+ "peft==0.6.2",
21
+ "transformers==4.39.3",
22
+ "sentencepiece==0.2.0",
23
+ ]
24
+
25
+ [tool.hatch.build.targets.wheel]
26
+ packages = ["."]
27
+
28
+ [flower]
29
+ publisher = "$username"
30
+
31
+ [flower.components]
32
+ serverapp = "$import_name.app:server"
33
+ clientapp = "$import_name.app:client"
34
+
35
+ [flower.engine]
36
+ name = "simulation"
37
+
38
+ [flower.engine.simulation.supernode]
39
+ num = $num_clients
40
+
41
+ [flower.engine.simulation]
42
+ backend_config = { client_resources = { num_cpus = 8, num_gpus = 1.0 } }
flwr/cli/run/run.py CHANGED
@@ -17,12 +17,14 @@
17
17
  import sys
18
18
  from enum import Enum
19
19
  from logging import DEBUG
20
+ from pathlib import Path
20
21
  from typing import Optional
21
22
 
22
23
  import typer
23
24
  from typing_extensions import Annotated
24
25
 
25
26
  from flwr.cli import config_utils
27
+ from flwr.cli.build import build
26
28
  from flwr.common.constant import SUPEREXEC_DEFAULT_ADDRESS
27
29
  from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel
28
30
  from flwr.common.logger import log
@@ -41,7 +43,10 @@ class Engine(str, Enum):
41
43
  def run(
42
44
  engine: Annotated[
43
45
  Optional[Engine],
44
- typer.Option(case_sensitive=False, help="The execution engine to run the app"),
46
+ typer.Option(
47
+ case_sensitive=False,
48
+ help="The engine to run FL with (currently only simulation is supported).",
49
+ ),
45
50
  ] = None,
46
51
  use_superexec: Annotated[
47
52
  bool,
@@ -49,10 +54,14 @@ def run(
49
54
  case_sensitive=False, help="Use this flag to use the new SuperExec API"
50
55
  ),
51
56
  ] = False,
57
+ directory: Annotated[
58
+ Optional[Path],
59
+ typer.Option(help="Path of the Flower project to run"),
60
+ ] = None,
52
61
  ) -> None:
53
62
  """Run Flower project."""
54
63
  if use_superexec:
55
- _start_superexec_run()
64
+ _start_superexec_run(directory)
56
65
  return
57
66
 
58
67
  typer.secho("Loading project configuration... ", fg=typer.colors.BLUE)
@@ -87,12 +96,16 @@ def run(
87
96
 
88
97
  if engine == Engine.SIMULATION:
89
98
  num_supernodes = config["flower"]["engine"]["simulation"]["supernode"]["num"]
99
+ backend_config = config["flower"]["engine"]["simulation"].get(
100
+ "backend_config", None
101
+ )
90
102
 
91
103
  typer.secho("Starting run... ", fg=typer.colors.BLUE)
92
104
  _run_simulation(
93
105
  server_app_attr=server_app_ref,
94
106
  client_app_attr=client_app_ref,
95
107
  num_supernodes=num_supernodes,
108
+ backend_config=backend_config,
96
109
  )
97
110
  else:
98
111
  typer.secho(
@@ -102,7 +115,7 @@ def run(
102
115
  )
103
116
 
104
117
 
105
- def _start_superexec_run() -> None:
118
+ def _start_superexec_run(directory: Optional[Path]) -> None:
106
119
  def on_channel_state_change(channel_connectivity: str) -> None:
107
120
  """Log channel connectivity."""
108
121
  log(DEBUG, channel_connectivity)
@@ -117,5 +130,8 @@ def _start_superexec_run() -> None:
117
130
  channel.subscribe(on_channel_state_change)
118
131
  stub = ExecStub(channel)
119
132
 
120
- req = StartRunRequest()
121
- stub.StartRun(req)
133
+ fab_path = build(directory)
134
+
135
+ req = StartRunRequest(fab_file=Path(fab_path).read_bytes())
136
+ res = stub.StartRun(req)
137
+ typer.secho(f"🎊 Successfully started run {res.run_id}", fg=typer.colors.GREEN)
flwr/client/__init__.py CHANGED
@@ -23,11 +23,13 @@ from .numpy_client import NumPyClient as NumPyClient
23
23
  from .supernode import run_client_app as run_client_app
24
24
  from .supernode import run_supernode as run_supernode
25
25
  from .typing import ClientFn as ClientFn
26
+ from .typing import ClientFnExt as ClientFnExt
26
27
 
27
28
  __all__ = [
28
29
  "Client",
29
30
  "ClientApp",
30
31
  "ClientFn",
32
+ "ClientFnExt",
31
33
  "NumPyClient",
32
34
  "mod",
33
35
  "run_client_app",
flwr/client/app.py CHANGED
@@ -26,7 +26,7 @@ from grpc import RpcError
26
26
 
27
27
  from flwr.client.client import Client
28
28
  from flwr.client.client_app import ClientApp, LoadClientAppError
29
- from flwr.client.typing import ClientFn
29
+ from flwr.client.typing import ClientFnExt
30
30
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, Message, event
31
31
  from flwr.common.address import parse_address
32
32
  from flwr.common.constant import (
@@ -51,7 +51,7 @@ from .numpy_client import NumPyClient
51
51
 
52
52
 
53
53
  def _check_actionable_client(
54
- client: Optional[Client], client_fn: Optional[ClientFn]
54
+ client: Optional[Client], client_fn: Optional[ClientFnExt]
55
55
  ) -> None:
56
56
  if client_fn is None and client is None:
57
57
  raise ValueError(
@@ -72,7 +72,7 @@ def _check_actionable_client(
72
72
  def start_client(
73
73
  *,
74
74
  server_address: str,
75
- client_fn: Optional[ClientFn] = None,
75
+ client_fn: Optional[ClientFnExt] = None,
76
76
  client: Optional[Client] = None,
77
77
  grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
78
78
  root_certificates: Optional[Union[bytes, str]] = None,
@@ -92,7 +92,7 @@ def start_client(
92
92
  The IPv4 or IPv6 address of the server. If the Flower
93
93
  server runs on the same machine on port 8080, then `server_address`
94
94
  would be `"[::]:8080"`.
95
- client_fn : Optional[ClientFn]
95
+ client_fn : Optional[ClientFnExt]
96
96
  A callable that instantiates a Client. (default: None)
97
97
  client : Optional[flwr.client.Client]
98
98
  An implementation of the abstract base
@@ -136,7 +136,7 @@ def start_client(
136
136
 
137
137
  Starting an SSL-enabled gRPC client using system certificates:
138
138
 
139
- >>> def client_fn(cid: str):
139
+ >>> def client_fn(node_id: int, partition_id: Optional[int]):
140
140
  >>> return FlowerClient()
141
141
  >>>
142
142
  >>> start_client(
@@ -180,7 +180,7 @@ def _start_client_internal(
180
180
  *,
181
181
  server_address: str,
182
182
  load_client_app_fn: Optional[Callable[[str, str], ClientApp]] = None,
183
- client_fn: Optional[ClientFn] = None,
183
+ client_fn: Optional[ClientFnExt] = None,
184
184
  client: Optional[Client] = None,
185
185
  grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
186
186
  root_certificates: Optional[Union[bytes, str]] = None,
@@ -191,6 +191,7 @@ def _start_client_internal(
191
191
  ] = None,
192
192
  max_retries: Optional[int] = None,
193
193
  max_wait_time: Optional[float] = None,
194
+ partition_id: Optional[int] = None,
194
195
  ) -> None:
195
196
  """Start a Flower client node which connects to a Flower server.
196
197
 
@@ -202,7 +203,7 @@ def _start_client_internal(
202
203
  would be `"[::]:8080"`.
203
204
  load_client_app_fn : Optional[Callable[[], ClientApp]] (default: None)
204
205
  A function that can be used to load a `ClientApp` instance.
205
- client_fn : Optional[ClientFn]
206
+ client_fn : Optional[ClientFnExt]
206
207
  A callable that instantiates a Client. (default: None)
207
208
  client : Optional[flwr.client.Client]
208
209
  An implementation of the abstract base
@@ -234,6 +235,9 @@ def _start_client_internal(
234
235
  The maximum duration before the client stops trying to
235
236
  connect to the server in case of connection error.
236
237
  If set to None, there is no limit to the total time.
238
+ partitioni_id: Optional[int] (default: None)
239
+ The data partition index associated with this node. Better suited for
240
+ prototyping purposes.
237
241
  """
238
242
  if insecure is None:
239
243
  insecure = root_certificates is None
@@ -244,7 +248,8 @@ def _start_client_internal(
244
248
  if client_fn is None:
245
249
  # Wrap `Client` instance in `client_fn`
246
250
  def single_client_factory(
247
- cid: str, # pylint: disable=unused-argument
251
+ node_id: int, # pylint: disable=unused-argument
252
+ partition_id: Optional[int], # pylint: disable=unused-argument
248
253
  ) -> Client:
249
254
  if client is None: # Added this to keep mypy happy
250
255
  raise ValueError(
@@ -293,7 +298,7 @@ def _start_client_internal(
293
298
  retry_invoker = RetryInvoker(
294
299
  wait_gen_factory=exponential,
295
300
  recoverable_exceptions=connection_error_type,
296
- max_tries=max_retries,
301
+ max_tries=max_retries + 1 if max_retries is not None else None,
297
302
  max_time=max_wait_time,
298
303
  on_giveup=lambda retry_state: (
299
304
  log(
@@ -309,7 +314,7 @@ def _start_client_internal(
309
314
  on_backoff=_on_backoff,
310
315
  )
311
316
 
312
- node_state = NodeState()
317
+ node_state = NodeState(partition_id=partition_id)
313
318
  # run_id -> (fab_id, fab_version)
314
319
  run_info: Dict[int, Tuple[str, str]] = {}
315
320
 
flwr/client/client_app.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -15,19 +15,42 @@
15
15
  """Flower ClientApp."""
16
16
 
17
17
 
18
+ import inspect
18
19
  from typing import Callable, List, Optional
19
20
 
21
+ from flwr.client.client import Client
20
22
  from flwr.client.message_handler.message_handler import (
21
23
  handle_legacy_message_from_msgtype,
22
24
  )
23
25
  from flwr.client.mod.utils import make_ffn
24
- from flwr.client.typing import ClientFn, Mod
26
+ from flwr.client.typing import ClientFnExt, Mod
25
27
  from flwr.common import Context, Message, MessageType
26
- from flwr.common.logger import warn_preview_feature
28
+ from flwr.common.logger import warn_deprecated_feature, warn_preview_feature
27
29
 
28
30
  from .typing import ClientAppCallable
29
31
 
30
32
 
33
+ def _inspect_maybe_adapt_client_fn_signature(client_fn: ClientFnExt) -> ClientFnExt:
34
+ client_fn_args = inspect.signature(client_fn).parameters
35
+
36
+ if not all(key in client_fn_args for key in ["node_id", "partition_id"]):
37
+ warn_deprecated_feature(
38
+ "`client_fn` now expects a signature `def client_fn(node_id: int, "
39
+ "partition_id: Optional[int])`.\nYou provided `client_fn` with signature: "
40
+ f"{dict(client_fn_args.items())}"
41
+ )
42
+
43
+ # Wrap depcreated client_fn inside a function with the expected signature
44
+ def adaptor_fn(
45
+ node_id: int, partition_id: Optional[int] # pylint: disable=unused-argument
46
+ ) -> Client:
47
+ return client_fn(str(partition_id)) # type: ignore
48
+
49
+ return adaptor_fn
50
+
51
+ return client_fn
52
+
53
+
31
54
  class ClientAppException(Exception):
32
55
  """Exception raised when an exception is raised while executing a ClientApp."""
33
56
 
@@ -48,7 +71,7 @@ class ClientApp:
48
71
  >>> class FlowerClient(NumPyClient):
49
72
  >>> # ...
50
73
  >>>
51
- >>> def client_fn(cid):
74
+ >>> def client_fn(node_id: int, partition_id: Optional[int]):
52
75
  >>> return FlowerClient().to_client()
53
76
  >>>
54
77
  >>> app = ClientApp(client_fn)
@@ -65,7 +88,7 @@ class ClientApp:
65
88
 
66
89
  def __init__(
67
90
  self,
68
- client_fn: Optional[ClientFn] = None, # Only for backward compatibility
91
+ client_fn: Optional[ClientFnExt] = None, # Only for backward compatibility
69
92
  mods: Optional[List[Mod]] = None,
70
93
  ) -> None:
71
94
  self._mods: List[Mod] = mods if mods is not None else []
@@ -74,6 +97,8 @@ class ClientApp:
74
97
  self._call: Optional[ClientAppCallable] = None
75
98
  if client_fn is not None:
76
99
 
100
+ client_fn = _inspect_maybe_adapt_client_fn_signature(client_fn)
101
+
77
102
  def ffn(
78
103
  message: Message,
79
104
  context: Context,
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2022 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2022 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2022 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -14,7 +14,6 @@
14
14
  # ==============================================================================
15
15
  """Client-side message handler."""
16
16
 
17
-
18
17
  from logging import WARN
19
18
  from typing import Optional, Tuple, cast
20
19
 
@@ -25,7 +24,7 @@ from flwr.client.client import (
25
24
  maybe_call_get_properties,
26
25
  )
27
26
  from flwr.client.numpy_client import NumPyClient
28
- from flwr.client.typing import ClientFn
27
+ from flwr.client.typing import ClientFnExt
29
28
  from flwr.common import ConfigsRecord, Context, Message, Metadata, RecordSet, log
30
29
  from flwr.common.constant import MessageType, MessageTypeLegacy
31
30
  from flwr.common.recordset_compat import (
@@ -90,10 +89,10 @@ def handle_control_message(message: Message) -> Tuple[Optional[Message], int]:
90
89
 
91
90
 
92
91
  def handle_legacy_message_from_msgtype(
93
- client_fn: ClientFn, message: Message, context: Context
92
+ client_fn: ClientFnExt, message: Message, context: Context
94
93
  ) -> Message:
95
94
  """Handle legacy message in the inner most mod."""
96
- client = client_fn(str(message.metadata.partition_id))
95
+ client = client_fn(message.metadata.dst_node_id, context.partition_id)
97
96
 
98
97
  # Check if NumPyClient is returend
99
98
  if isinstance(client, NumPyClient):
@@ -1,4 +1,4 @@
1
- # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
flwr/client/mod/utils.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
flwr/client/node_state.py CHANGED
@@ -15,7 +15,7 @@
15
15
  """Node state."""
16
16
 
17
17
 
18
- from typing import Any, Dict
18
+ from typing import Any, Dict, Optional
19
19
 
20
20
  from flwr.common import Context, RecordSet
21
21
 
@@ -23,14 +23,17 @@ from flwr.common import Context, RecordSet
23
23
  class NodeState:
24
24
  """State of a node where client nodes execute runs."""
25
25
 
26
- def __init__(self) -> None:
26
+ def __init__(self, partition_id: Optional[int]) -> None:
27
27
  self._meta: Dict[str, Any] = {} # holds metadata about the node
28
28
  self.run_contexts: Dict[int, Context] = {}
29
+ self._partition_id = partition_id
29
30
 
30
31
  def register_context(self, run_id: int) -> None:
31
32
  """Register new run context for this node."""
32
33
  if run_id not in self.run_contexts:
33
- self.run_contexts[run_id] = Context(state=RecordSet())
34
+ self.run_contexts[run_id] = Context(
35
+ state=RecordSet(), partition_id=self._partition_id
36
+ )
34
37
 
35
38
  def retrieve_context(self, run_id: int) -> Context:
36
39
  """Get run context given a run_id."""
@@ -41,7 +41,7 @@ def test_multirun_in_node_state() -> None:
41
41
  expected_values = {0: "1", 1: "1" * 3, 2: "1" * 2, 3: "1", 5: "1"}
42
42
 
43
43
  # NodeState
44
- node_state = NodeState()
44
+ node_state = NodeState(partition_id=None)
45
45
 
46
46
  for task in tasks:
47
47
  run_id = task.run_id
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -67,6 +67,7 @@ def run_supernode() -> None:
67
67
  authentication_keys=authentication_keys,
68
68
  max_retries=args.max_retries,
69
69
  max_wait_time=args.max_wait_time,
70
+ partition_id=args.partition_id,
70
71
  )
71
72
 
72
73
  # Graceful shutdown
@@ -267,7 +268,7 @@ def _parse_args_run_supernode() -> argparse.ArgumentParser:
267
268
  "--flwr-dir",
268
269
  default=None,
269
270
  help="""The path containing installed Flower Apps.
270
- By default, this value isequal to:
271
+ By default, this value is equal to:
271
272
 
272
273
  - `$FLWR_HOME/` if `$FLWR_HOME` is defined
273
274
  - `$XDG_DATA_HOME/.flwr/` if `$XDG_DATA_HOME` is defined
@@ -344,8 +345,8 @@ def _parse_args_common(parser: argparse.ArgumentParser) -> None:
344
345
  "--max-retries",
345
346
  type=int,
346
347
  default=None,
347
- help="The maximum number of times the client will try to connect to the"
348
- "server before giving up in case of a connection error. By default,"
348
+ help="The maximum number of times the client will try to reconnect to the"
349
+ "SuperLink before giving up in case of a connection error. By default,"
349
350
  "it is set to None, meaning there is no limit to the number of tries.",
350
351
  )
351
352
  parser.add_argument(
@@ -353,7 +354,7 @@ def _parse_args_common(parser: argparse.ArgumentParser) -> None:
353
354
  type=float,
354
355
  default=None,
355
356
  help="The maximum duration before the client stops trying to"
356
- "connect to the server in case of connection error. By default, it"
357
+ "connect to the SuperLink in case of connection error. By default, it"
357
358
  "is set to None, meaning there is no limit to the total time.",
358
359
  )
359
360
  parser.add_argument(
@@ -373,6 +374,13 @@ def _parse_args_common(parser: argparse.ArgumentParser) -> None:
373
374
  type=str,
374
375
  help="The SuperNode's public key (as a path str) to enable authentication.",
375
376
  )
377
+ parser.add_argument(
378
+ "--partition-id",
379
+ type=int,
380
+ help="The data partition index associated with this SuperNode. Better suited "
381
+ "for prototyping purposes where a SuperNode might only load a fraction of an "
382
+ "artificially partitioned dataset (e.g. using `flwr-datasets`)",
383
+ )
376
384
 
377
385
 
378
386
  def _try_setup_client_authentication(
flwr/client/typing.py CHANGED
@@ -15,7 +15,7 @@
15
15
  """Custom types for Flower clients."""
16
16
 
17
17
 
18
- from typing import Callable
18
+ from typing import Callable, Optional
19
19
 
20
20
  from flwr.common import Context, Message
21
21
 
@@ -23,6 +23,7 @@ from .client import Client as Client
23
23
 
24
24
  # Compatibility
25
25
  ClientFn = Callable[[str], Client]
26
+ ClientFnExt = Callable[[int, Optional[int]], Client]
26
27
 
27
28
  ClientAppCallable = Callable[[Message, Context], Message]
28
29
  Mod = Callable[[Message, Context, ClientAppCallable], Message]
flwr/common/address.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
flwr/common/config.py CHANGED
@@ -24,14 +24,16 @@ from flwr.cli.config_utils import validate_fields
24
24
  from flwr.common.constant import APP_DIR, FAB_CONFIG_FILE, FLWR_HOME
25
25
 
26
26
 
27
- def get_flwr_dir() -> Path:
27
+ def get_flwr_dir(provided_path: Optional[str] = None) -> Path:
28
28
  """Return the Flower home directory based on env variables."""
29
- return Path(
30
- os.getenv(
31
- FLWR_HOME,
32
- f"{os.getenv('XDG_DATA_HOME', os.getenv('HOME'))}/.flwr",
29
+ if provided_path is None or not Path(provided_path).is_dir():
30
+ return Path(
31
+ os.getenv(
32
+ FLWR_HOME,
33
+ f"{os.getenv('XDG_DATA_HOME', os.getenv('HOME'))}/.flwr",
34
+ )
33
35
  )
34
- )
36
+ return Path(provided_path).absolute()
35
37
 
36
38
 
37
39
  def get_project_dir(
flwr/common/constant.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2020 Flower Labs GmbH. All Rights Reserved.
1
+ # Copyright 2023 Flower Labs GmbH. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -46,6 +46,9 @@ PING_BASE_MULTIPLIER = 0.8
46
46
  PING_RANDOM_RANGE = (-0.1, 0.1)
47
47
  PING_MAX_INTERVAL = 1e300
48
48
 
49
+ # IDs
50
+ RUN_ID_NUM_BYTES = 8
51
+ NODE_ID_NUM_BYTES = 8
49
52
  GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY = "flower-version"
50
53
  GRPC_ADAPTER_METADATA_SHOULD_EXIT_KEY = "should-exit"
51
54
 
flwr/common/context.py CHANGED
@@ -16,13 +16,14 @@
16
16
 
17
17
 
18
18
  from dataclasses import dataclass
19
+ from typing import Optional
19
20
 
20
21
  from .record import RecordSet
21
22
 
22
23
 
23
24
  @dataclass
24
25
  class Context:
25
- """State of your run.
26
+ """Context of your run.
26
27
 
27
28
  Parameters
28
29
  ----------
@@ -33,6 +34,15 @@ class Context:
33
34
  executing mods. It can also be used as a memory to access
34
35
  at different points during the lifecycle of this entity (e.g. across
35
36
  multiple rounds)
37
+ partition_id : Optional[int] (default: None)
38
+ An index that specifies the data partition that the ClientApp using this Context
39
+ object should make use of. Setting this attribute is better suited for
40
+ simulation or proto typing setups.
36
41
  """
37
42
 
38
43
  state: RecordSet
44
+ partition_id: Optional[int]
45
+
46
+ def __init__(self, state: RecordSet, partition_id: Optional[int] = None) -> None:
47
+ self.state = state
48
+ self.partition_id = partition_id