flwr-nightly 1.10.0.dev20240707__py3-none-any.whl → 1.11.0.dev20240724__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

Files changed (99) hide show
  1. flwr/cli/build.py +16 -2
  2. flwr/cli/config_utils.py +47 -27
  3. flwr/cli/install.py +17 -1
  4. flwr/cli/new/new.py +32 -21
  5. flwr/cli/new/templates/app/code/{client.hf.py.tpl → client.huggingface.py.tpl} +15 -5
  6. flwr/cli/new/templates/app/code/client.jax.py.tpl +2 -1
  7. flwr/cli/new/templates/app/code/client.mlx.py.tpl +36 -13
  8. flwr/cli/new/templates/app/code/client.numpy.py.tpl +2 -1
  9. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +16 -5
  10. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +6 -3
  11. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +25 -5
  12. flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +22 -19
  13. flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +5 -3
  14. flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +1 -1
  15. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -0
  16. flwr/cli/new/templates/app/code/server.jax.py.tpl +16 -8
  17. flwr/cli/new/templates/app/code/server.mlx.py.tpl +12 -7
  18. flwr/cli/new/templates/app/code/server.numpy.py.tpl +16 -8
  19. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +15 -13
  20. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +17 -10
  21. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +16 -13
  22. flwr/cli/new/templates/app/code/{task.hf.py.tpl → task.huggingface.py.tpl} +14 -2
  23. flwr/cli/new/templates/app/code/task.mlx.py.tpl +14 -2
  24. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +14 -3
  25. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +13 -1
  26. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +9 -12
  27. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +38 -0
  28. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +17 -11
  29. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +17 -12
  30. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +12 -12
  31. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +13 -12
  32. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +12 -12
  33. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +15 -12
  34. flwr/cli/run/run.py +133 -54
  35. flwr/client/app.py +56 -24
  36. flwr/client/client_app.py +28 -8
  37. flwr/client/grpc_adapter_client/connection.py +3 -2
  38. flwr/client/grpc_client/connection.py +3 -2
  39. flwr/client/grpc_rere_client/connection.py +17 -6
  40. flwr/client/message_handler/message_handler.py +1 -1
  41. flwr/client/node_state.py +59 -12
  42. flwr/client/node_state_tests.py +4 -3
  43. flwr/client/rest_client/connection.py +19 -8
  44. flwr/client/supernode/app.py +39 -39
  45. flwr/client/typing.py +2 -2
  46. flwr/common/config.py +92 -2
  47. flwr/common/constant.py +3 -0
  48. flwr/common/context.py +24 -9
  49. flwr/common/logger.py +25 -0
  50. flwr/common/object_ref.py +84 -21
  51. flwr/common/serde.py +45 -0
  52. flwr/common/telemetry.py +17 -0
  53. flwr/common/typing.py +5 -0
  54. flwr/proto/common_pb2.py +36 -0
  55. flwr/proto/common_pb2.pyi +121 -0
  56. flwr/proto/common_pb2_grpc.py +4 -0
  57. flwr/proto/common_pb2_grpc.pyi +4 -0
  58. flwr/proto/driver_pb2.py +24 -19
  59. flwr/proto/driver_pb2.pyi +21 -1
  60. flwr/proto/exec_pb2.py +20 -11
  61. flwr/proto/exec_pb2.pyi +41 -1
  62. flwr/proto/run_pb2.py +12 -7
  63. flwr/proto/run_pb2.pyi +22 -1
  64. flwr/proto/task_pb2.py +7 -8
  65. flwr/server/__init__.py +2 -0
  66. flwr/server/compat/legacy_context.py +5 -4
  67. flwr/server/driver/grpc_driver.py +82 -140
  68. flwr/server/run_serverapp.py +40 -18
  69. flwr/server/server_app.py +56 -10
  70. flwr/server/serverapp_components.py +52 -0
  71. flwr/server/superlink/driver/driver_servicer.py +18 -3
  72. flwr/server/superlink/fleet/message_handler/message_handler.py +13 -2
  73. flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
  74. flwr/server/superlink/fleet/vce/backend/backend.py +4 -4
  75. flwr/server/superlink/fleet/vce/backend/raybackend.py +10 -10
  76. flwr/server/superlink/fleet/vce/vce_api.py +149 -117
  77. flwr/server/superlink/state/in_memory_state.py +11 -3
  78. flwr/server/superlink/state/sqlite_state.py +23 -8
  79. flwr/server/superlink/state/state.py +7 -2
  80. flwr/server/typing.py +2 -0
  81. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -2
  82. flwr/simulation/__init__.py +1 -1
  83. flwr/simulation/app.py +4 -3
  84. flwr/simulation/ray_transport/ray_actor.py +15 -19
  85. flwr/simulation/ray_transport/ray_client_proxy.py +22 -9
  86. flwr/simulation/run_simulation.py +269 -70
  87. flwr/superexec/app.py +17 -11
  88. flwr/superexec/deployment.py +111 -35
  89. flwr/superexec/exec_grpc.py +5 -1
  90. flwr/superexec/exec_servicer.py +6 -1
  91. flwr/superexec/executor.py +21 -0
  92. flwr/superexec/simulation.py +181 -0
  93. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/METADATA +3 -2
  94. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/RECORD +97 -91
  95. flwr/cli/new/templates/app/code/server.hf.py.tpl +0 -17
  96. flwr/cli/new/templates/app/pyproject.hf.toml.tpl +0 -37
  97. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/LICENSE +0 -0
  98. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/WHEEL +0 -0
  99. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/entry_points.txt +0 -0
@@ -40,7 +40,12 @@ from flwr.common.grpc import create_channel
40
40
  from flwr.common.logger import log
41
41
  from flwr.common.message import Message, Metadata
42
42
  from flwr.common.retry_invoker import RetryInvoker
43
- from flwr.common.serde import message_from_taskins, message_to_taskres
43
+ from flwr.common.serde import (
44
+ message_from_taskins,
45
+ message_to_taskres,
46
+ user_config_from_proto,
47
+ )
48
+ from flwr.common.typing import Run
44
49
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
45
50
  CreateNodeRequest,
46
51
  DeleteNodeRequest,
@@ -78,9 +83,9 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
78
83
  Tuple[
79
84
  Callable[[], Optional[Message]],
80
85
  Callable[[Message], None],
86
+ Optional[Callable[[], Optional[int]]],
81
87
  Optional[Callable[[], None]],
82
- Optional[Callable[[], None]],
83
- Optional[Callable[[int], Tuple[str, str]]],
88
+ Optional[Callable[[int], Run]],
84
89
  ]
85
90
  ]:
86
91
  """Primitives for request/response-based interaction with a server.
@@ -175,7 +180,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
175
180
  if not ping_stop_event.is_set():
176
181
  ping_stop_event.wait(next_interval)
177
182
 
178
- def create_node() -> None:
183
+ def create_node() -> Optional[int]:
179
184
  """Set create_node."""
180
185
  # Call FleetAPI
181
186
  create_node_request = CreateNodeRequest(ping_interval=PING_DEFAULT_INTERVAL)
@@ -188,6 +193,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
188
193
  nonlocal node, ping_thread
189
194
  node = cast(Node, create_node_response.node)
190
195
  ping_thread = start_ping_loop(ping, ping_stop_event)
196
+ return node.node_id
191
197
 
192
198
  def delete_node() -> None:
193
199
  """Set delete_node."""
@@ -266,7 +272,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
266
272
  # Cleanup
267
273
  metadata = None
268
274
 
269
- def get_run(run_id: int) -> Tuple[str, str]:
275
+ def get_run(run_id: int) -> Run:
270
276
  # Call FleetAPI
271
277
  get_run_request = GetRunRequest(run_id=run_id)
272
278
  get_run_response: GetRunResponse = retry_invoker.invoke(
@@ -275,7 +281,12 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
275
281
  )
276
282
 
277
283
  # Return fab_id and fab_version
278
- return get_run_response.run.fab_id, get_run_response.run.fab_version
284
+ return Run(
285
+ run_id,
286
+ get_run_response.run.fab_id,
287
+ get_run_response.run.fab_version,
288
+ user_config_from_proto(get_run_response.run.override_config),
289
+ )
279
290
 
280
291
  try:
281
292
  # Yield methods
@@ -92,7 +92,7 @@ def handle_legacy_message_from_msgtype(
92
92
  client_fn: ClientFnExt, message: Message, context: Context
93
93
  ) -> Message:
94
94
  """Handle legacy message in the inner most mod."""
95
- client = client_fn(message.metadata.dst_node_id, context.partition_id)
95
+ client = client_fn(context)
96
96
 
97
97
  # Check if NumPyClient is returend
98
98
  if isinstance(client, NumPyClient):
flwr/client/node_state.py CHANGED
@@ -15,30 +15,72 @@
15
15
  """Node state."""
16
16
 
17
17
 
18
- from typing import Any, Dict, Optional
18
+ from dataclasses import dataclass
19
+ from pathlib import Path
20
+ from typing import Dict, Optional
19
21
 
20
22
  from flwr.common import Context, RecordSet
23
+ from flwr.common.config import get_fused_config, get_fused_config_from_dir
24
+ from flwr.common.typing import Run, UserConfig
25
+
26
+
27
+ @dataclass()
28
+ class RunInfo:
29
+ """Contains the Context and initial run_config of a Run."""
30
+
31
+ context: Context
32
+ initial_run_config: UserConfig
21
33
 
22
34
 
23
35
  class NodeState:
24
36
  """State of a node where client nodes execute runs."""
25
37
 
26
- def __init__(self, partition_id: Optional[int]) -> None:
27
- self._meta: Dict[str, Any] = {} # holds metadata about the node
28
- self.run_contexts: Dict[int, Context] = {}
29
- self._partition_id = partition_id
38
+ def __init__(
39
+ self,
40
+ node_id: int,
41
+ node_config: UserConfig,
42
+ ) -> None:
43
+ self.node_id = node_id
44
+ self.node_config = node_config
45
+ self.run_infos: Dict[int, RunInfo] = {}
30
46
 
31
- def register_context(self, run_id: int) -> None:
47
+ def register_context(
48
+ self,
49
+ run_id: int,
50
+ run: Optional[Run] = None,
51
+ flwr_path: Optional[Path] = None,
52
+ app_dir: Optional[str] = None,
53
+ ) -> None:
32
54
  """Register new run context for this node."""
33
- if run_id not in self.run_contexts:
34
- self.run_contexts[run_id] = Context(
35
- state=RecordSet(), partition_id=self._partition_id
55
+ if run_id not in self.run_infos:
56
+ initial_run_config = {}
57
+ if app_dir:
58
+ # Load from app directory
59
+ app_path = Path(app_dir)
60
+ if app_path.is_dir():
61
+ override_config = run.override_config if run else {}
62
+ initial_run_config = get_fused_config_from_dir(
63
+ app_path, override_config
64
+ )
65
+ else:
66
+ raise ValueError("The specified `app_dir` must be a directory.")
67
+ else:
68
+ # Load from .fab
69
+ initial_run_config = get_fused_config(run, flwr_path) if run else {}
70
+ self.run_infos[run_id] = RunInfo(
71
+ initial_run_config=initial_run_config,
72
+ context=Context(
73
+ node_id=self.node_id,
74
+ node_config=self.node_config,
75
+ state=RecordSet(),
76
+ run_config=initial_run_config.copy(),
77
+ ),
36
78
  )
37
79
 
38
80
  def retrieve_context(self, run_id: int) -> Context:
39
81
  """Get run context given a run_id."""
40
- if run_id in self.run_contexts:
41
- return self.run_contexts[run_id]
82
+ if run_id in self.run_infos:
83
+ return self.run_infos[run_id].context
42
84
 
43
85
  raise RuntimeError(
44
86
  f"Context for run_id={run_id} doesn't exist."
@@ -48,4 +90,9 @@ class NodeState:
48
90
 
49
91
  def update_context(self, run_id: int, context: Context) -> None:
50
92
  """Update run context."""
51
- self.run_contexts[run_id] = context
93
+ if context.run_config != self.run_infos[run_id].initial_run_config:
94
+ raise ValueError(
95
+ "The `run_config` field of the `Context` object cannot be "
96
+ f"modified (run_id: {run_id})."
97
+ )
98
+ self.run_infos[run_id].context = context
@@ -41,7 +41,7 @@ def test_multirun_in_node_state() -> None:
41
41
  expected_values = {0: "1", 1: "1" * 3, 2: "1" * 2, 3: "1", 5: "1"}
42
42
 
43
43
  # NodeState
44
- node_state = NodeState(partition_id=None)
44
+ node_state = NodeState(node_id=0, node_config={})
45
45
 
46
46
  for task in tasks:
47
47
  run_id = task.run_id
@@ -59,7 +59,8 @@ def test_multirun_in_node_state() -> None:
59
59
  node_state.update_context(run_id=run_id, context=updated_state)
60
60
 
61
61
  # Verify values
62
- for run_id, context in node_state.run_contexts.items():
62
+ for run_id, run_info in node_state.run_infos.items():
63
63
  assert (
64
- context.state.configs_records["counter"]["count"] == expected_values[run_id]
64
+ run_info.context.state.configs_records["counter"]["count"]
65
+ == expected_values[run_id]
65
66
  )
@@ -40,7 +40,12 @@ from flwr.common.constant import (
40
40
  from flwr.common.logger import log
41
41
  from flwr.common.message import Message, Metadata
42
42
  from flwr.common.retry_invoker import RetryInvoker
43
- from flwr.common.serde import message_from_taskins, message_to_taskres
43
+ from flwr.common.serde import (
44
+ message_from_taskins,
45
+ message_to_taskres,
46
+ user_config_from_proto,
47
+ )
48
+ from flwr.common.typing import Run
44
49
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
45
50
  CreateNodeRequest,
46
51
  CreateNodeResponse,
@@ -89,9 +94,9 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915
89
94
  Tuple[
90
95
  Callable[[], Optional[Message]],
91
96
  Callable[[Message], None],
97
+ Optional[Callable[[], Optional[int]]],
92
98
  Optional[Callable[[], None]],
93
- Optional[Callable[[], None]],
94
- Optional[Callable[[int], Tuple[str, str]]],
99
+ Optional[Callable[[int], Run]],
95
100
  ]
96
101
  ]:
97
102
  """Primitives for request/response-based interaction with a server.
@@ -236,19 +241,20 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915
236
241
  if not ping_stop_event.is_set():
237
242
  ping_stop_event.wait(next_interval)
238
243
 
239
- def create_node() -> None:
244
+ def create_node() -> Optional[int]:
240
245
  """Set create_node."""
241
246
  req = CreateNodeRequest(ping_interval=PING_DEFAULT_INTERVAL)
242
247
 
243
248
  # Send the request
244
249
  res = _request(req, CreateNodeResponse, PATH_CREATE_NODE)
245
250
  if res is None:
246
- return
251
+ return None
247
252
 
248
253
  # Remember the node and the ping-loop thread
249
254
  nonlocal node, ping_thread
250
255
  node = res.node
251
256
  ping_thread = start_ping_loop(ping, ping_stop_event)
257
+ return node.node_id
252
258
 
253
259
  def delete_node() -> None:
254
260
  """Set delete_node."""
@@ -344,16 +350,21 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915
344
350
  res.results, # pylint: disable=no-member
345
351
  )
346
352
 
347
- def get_run(run_id: int) -> Tuple[str, str]:
353
+ def get_run(run_id: int) -> Run:
348
354
  # Construct the request
349
355
  req = GetRunRequest(run_id=run_id)
350
356
 
351
357
  # Send the request
352
358
  res = _request(req, GetRunResponse, PATH_GET_RUN)
353
359
  if res is None:
354
- return "", ""
360
+ return Run(run_id, "", "", {})
355
361
 
356
- return res.run.fab_id, res.run.fab_version
362
+ return Run(
363
+ run_id,
364
+ res.run.fab_id,
365
+ res.run.fab_version,
366
+ user_config_from_proto(res.run.override_config),
367
+ )
357
368
 
358
369
  try:
359
370
  # Yield methods
@@ -29,7 +29,12 @@ from cryptography.hazmat.primitives.serialization import (
29
29
 
30
30
  from flwr.client.client_app import ClientApp, LoadClientAppError
31
31
  from flwr.common import EventType, event
32
- from flwr.common.config import get_flwr_dir, get_project_config, get_project_dir
32
+ from flwr.common.config import (
33
+ get_flwr_dir,
34
+ get_project_config,
35
+ get_project_dir,
36
+ parse_config_args,
37
+ )
33
38
  from flwr.common.constant import (
34
39
  TRANSPORT_TYPE_GRPC_ADAPTER,
35
40
  TRANSPORT_TYPE_GRPC_RERE,
@@ -55,7 +60,12 @@ def run_supernode() -> None:
55
60
  _warn_deprecated_server_arg(args)
56
61
 
57
62
  root_certificates = _get_certificates(args)
58
- load_fn = _get_load_client_app_fn(args, multi_app=True)
63
+ load_fn = _get_load_client_app_fn(
64
+ default_app_ref=getattr(args, "client-app"),
65
+ project_dir=args.dir,
66
+ flwr_dir=args.flwr_dir,
67
+ multi_app=True,
68
+ )
59
69
  authentication_keys = _try_setup_client_authentication(args)
60
70
 
61
71
  _start_client_internal(
@@ -67,7 +77,8 @@ def run_supernode() -> None:
67
77
  authentication_keys=authentication_keys,
68
78
  max_retries=args.max_retries,
69
79
  max_wait_time=args.max_wait_time,
70
- partition_id=args.partition_id,
80
+ node_config=parse_config_args([args.node_config]),
81
+ flwr_path=get_flwr_dir(args.flwr_dir),
71
82
  )
72
83
 
73
84
  # Graceful shutdown
@@ -87,11 +98,16 @@ def run_client_app() -> None:
87
98
  _warn_deprecated_server_arg(args)
88
99
 
89
100
  root_certificates = _get_certificates(args)
90
- load_fn = _get_load_client_app_fn(args, multi_app=False)
101
+ load_fn = _get_load_client_app_fn(
102
+ default_app_ref=getattr(args, "client-app"),
103
+ project_dir=args.dir,
104
+ multi_app=False,
105
+ )
91
106
  authentication_keys = _try_setup_client_authentication(args)
92
107
 
93
108
  _start_client_internal(
94
109
  server_address=args.superlink,
110
+ node_config=parse_config_args([args.node_config]),
95
111
  load_client_app_fn=load_fn,
96
112
  transport=args.transport,
97
113
  root_certificates=root_certificates,
@@ -159,7 +175,10 @@ def _get_certificates(args: argparse.Namespace) -> Optional[bytes]:
159
175
 
160
176
 
161
177
  def _get_load_client_app_fn(
162
- args: argparse.Namespace, multi_app: bool
178
+ default_app_ref: str,
179
+ project_dir: str,
180
+ multi_app: bool,
181
+ flwr_dir: Optional[str] = None,
163
182
  ) -> Callable[[str, str], ClientApp]:
164
183
  """Get the load_client_app_fn function.
165
184
 
@@ -170,34 +189,21 @@ def _get_load_client_app_fn(
170
189
  If `multi_app` is False, it ignores `fab_id` and `fab_version` and
171
190
  loads a default ClientApp.
172
191
  """
173
- # Find the Flower directory containing Flower Apps (only for multi-app)
174
- flwr_dir = Path("")
175
- if "flwr_dir" in args:
176
- if args.flwr_dir is None:
177
- flwr_dir = get_flwr_dir()
178
- else:
179
- flwr_dir = Path(args.flwr_dir).absolute()
180
-
181
- sys.path.insert(0, str(flwr_dir.absolute()))
182
-
183
- default_app_ref: str = getattr(args, "client-app")
184
-
185
192
  if not multi_app:
186
193
  log(
187
194
  DEBUG,
188
195
  "Flower SuperNode will load and validate ClientApp `%s`",
189
- getattr(args, "client-app"),
196
+ default_app_ref,
190
197
  )
191
- valid, error_msg = validate(default_app_ref)
198
+
199
+ valid, error_msg = validate(default_app_ref, project_dir=project_dir)
192
200
  if not valid and error_msg:
193
201
  raise LoadClientAppError(error_msg) from None
194
202
 
195
203
  def _load(fab_id: str, fab_version: str) -> ClientApp:
204
+ runtime_project_dir = Path(project_dir).absolute()
196
205
  # If multi-app feature is disabled
197
206
  if not multi_app:
198
- # Get sys path to be inserted
199
- sys_path = Path(args.dir).absolute()
200
-
201
207
  # Set app reference
202
208
  client_app_ref = default_app_ref
203
209
  # If multi-app feature is enabled but the fab id is not specified
@@ -208,27 +214,21 @@ def _get_load_client_app_fn(
208
214
  ) from None
209
215
 
210
216
  log(WARN, "FAB ID is not provided; the default ClientApp will be loaded.")
211
- # Get sys path to be inserted
212
- sys_path = Path(args.dir).absolute()
213
217
 
214
218
  # Set app reference
215
219
  client_app_ref = default_app_ref
216
220
  # If multi-app feature is enabled
217
221
  else:
218
222
  try:
219
- project_dir = get_project_dir(fab_id, fab_version, flwr_dir)
220
- config = get_project_config(project_dir)
223
+ runtime_project_dir = get_project_dir(
224
+ fab_id, fab_version, get_flwr_dir(flwr_dir)
225
+ )
226
+ config = get_project_config(runtime_project_dir)
221
227
  except Exception as e:
222
228
  raise LoadClientAppError("Failed to load ClientApp") from e
223
229
 
224
- # Get sys path to be inserted
225
- sys_path = Path(project_dir).absolute()
226
-
227
230
  # Set app reference
228
- client_app_ref = config["flower"]["components"]["clientapp"]
229
-
230
- # Set sys.path
231
- sys.path.insert(0, str(sys_path))
231
+ client_app_ref = config["tool"]["flwr"]["app"]["components"]["clientapp"]
232
232
 
233
233
  # Load ClientApp
234
234
  log(
@@ -236,7 +236,7 @@ def _get_load_client_app_fn(
236
236
  "Loading ClientApp `%s`",
237
237
  client_app_ref,
238
238
  )
239
- client_app = load_app(client_app_ref, LoadClientAppError, sys_path)
239
+ client_app = load_app(client_app_ref, LoadClientAppError, runtime_project_dir)
240
240
 
241
241
  if not isinstance(client_app, ClientApp):
242
242
  raise LoadClientAppError(
@@ -375,11 +375,11 @@ def _parse_args_common(parser: argparse.ArgumentParser) -> None:
375
375
  help="The SuperNode's public key (as a path str) to enable authentication.",
376
376
  )
377
377
  parser.add_argument(
378
- "--partition-id",
379
- type=int,
380
- help="The data partition index associated with this SuperNode. Better suited "
381
- "for prototyping purposes where a SuperNode might only load a fraction of an "
382
- "artificially partitioned dataset (e.g. using `flwr-datasets`)",
378
+ "--node-config",
379
+ type=str,
380
+ help="A comma separated list of key/value pairs (separated by `=`) to "
381
+ "configure the SuperNode. "
382
+ "E.g. --node-config 'key1=\"value1\",partition-id=0,num-partitions=100'",
383
383
  )
384
384
 
385
385
 
flwr/client/typing.py CHANGED
@@ -15,7 +15,7 @@
15
15
  """Custom types for Flower clients."""
16
16
 
17
17
 
18
- from typing import Callable, Optional
18
+ from typing import Callable
19
19
 
20
20
  from flwr.common import Context, Message
21
21
 
@@ -23,7 +23,7 @@ from .client import Client as Client
23
23
 
24
24
  # Compatibility
25
25
  ClientFn = Callable[[str], Client]
26
- ClientFnExt = Callable[[int, Optional[int]], Client]
26
+ ClientFnExt = Callable[[Context], Client]
27
27
 
28
28
  ClientAppCallable = Callable[[Message, Context], Message]
29
29
  Mod = Callable[[Message, Context, ClientAppCallable], Message]
flwr/common/config.py CHANGED
@@ -16,12 +16,13 @@
16
16
 
17
17
  import os
18
18
  from pathlib import Path
19
- from typing import Any, Dict, Optional, Union
19
+ from typing import Any, Dict, List, Optional, Tuple, Union, cast, get_args
20
20
 
21
21
  import tomli
22
22
 
23
23
  from flwr.cli.config_utils import validate_fields
24
24
  from flwr.common.constant import APP_DIR, FAB_CONFIG_FILE, FLWR_HOME
25
+ from flwr.common.typing import Run, UserConfig, UserConfigValue
25
26
 
26
27
 
27
28
  def get_flwr_dir(provided_path: Optional[str] = None) -> Path:
@@ -30,7 +31,7 @@ def get_flwr_dir(provided_path: Optional[str] = None) -> Path:
30
31
  return Path(
31
32
  os.getenv(
32
33
  FLWR_HOME,
33
- f"{os.getenv('XDG_DATA_HOME', os.getenv('HOME'))}/.flwr",
34
+ Path(f"{os.getenv('XDG_DATA_HOME', os.getenv('HOME'))}") / ".flwr",
34
35
  )
35
36
  )
36
37
  return Path(provided_path).absolute()
@@ -71,3 +72,92 @@ def get_project_config(project_dir: Union[str, Path]) -> Dict[str, Any]:
71
72
  )
72
73
 
73
74
  return config
75
+
76
+
77
+ def _fuse_dicts(
78
+ main_dict: UserConfig,
79
+ override_dict: UserConfig,
80
+ ) -> UserConfig:
81
+ fused_dict = main_dict.copy()
82
+
83
+ for key, value in override_dict.items():
84
+ if key in main_dict:
85
+ fused_dict[key] = value
86
+
87
+ return fused_dict
88
+
89
+
90
+ def get_fused_config_from_dir(
91
+ project_dir: Path, override_config: UserConfig
92
+ ) -> UserConfig:
93
+ """Merge the overrides from a given dict with the config from a Flower App."""
94
+ default_config = get_project_config(project_dir)["tool"]["flwr"]["app"].get(
95
+ "config", {}
96
+ )
97
+ flat_default_config = flatten_dict(default_config)
98
+
99
+ return _fuse_dicts(flat_default_config, override_config)
100
+
101
+
102
+ def get_fused_config(run: Run, flwr_dir: Optional[Path]) -> UserConfig:
103
+ """Merge the overrides from a `Run` with the config from a FAB.
104
+
105
+ Get the config using the fab_id and the fab_version, remove the nesting by adding
106
+ the nested keys as prefixes separated by dots, and fuse it with the override dict.
107
+ """
108
+ if not run.fab_id or not run.fab_version:
109
+ return {}
110
+
111
+ project_dir = get_project_dir(run.fab_id, run.fab_version, flwr_dir)
112
+
113
+ return get_fused_config_from_dir(project_dir, run.override_config)
114
+
115
+
116
+ def flatten_dict(
117
+ raw_dict: Optional[Dict[str, Any]], parent_key: str = ""
118
+ ) -> UserConfig:
119
+ """Flatten dict by joining nested keys with a given separator."""
120
+ if raw_dict is None:
121
+ return {}
122
+
123
+ items: List[Tuple[str, UserConfigValue]] = []
124
+ separator: str = "."
125
+ for k, v in raw_dict.items():
126
+ new_key = f"{parent_key}{separator}{k}" if parent_key else k
127
+ if isinstance(v, dict):
128
+ items.extend(flatten_dict(v, parent_key=new_key).items())
129
+ elif isinstance(v, get_args(UserConfigValue)):
130
+ items.append((new_key, cast(UserConfigValue, v)))
131
+ else:
132
+ raise ValueError(
133
+ f"The value for key {k} needs to be of type `int`, `float`, "
134
+ "`bool, `str`, or a `dict` of those.",
135
+ )
136
+ return dict(items)
137
+
138
+
139
+ def parse_config_args(
140
+ config: Optional[List[str]],
141
+ separator: str = ",",
142
+ ) -> UserConfig:
143
+ """Parse separator separated list of key-value pairs separated by '='."""
144
+ overrides: UserConfig = {}
145
+
146
+ if config is None:
147
+ return overrides
148
+
149
+ for config_line in config:
150
+ if config_line:
151
+ overrides_list = config_line.split(separator)
152
+ if (
153
+ len(overrides_list) == 1
154
+ and "=" not in overrides_list
155
+ and overrides_list[0].endswith(".toml")
156
+ ):
157
+ with Path(overrides_list[0]).open("rb") as config_file:
158
+ overrides = flatten_dict(tomli.load(config_file))
159
+ else:
160
+ toml_str = "\n".join(overrides_list)
161
+ overrides.update(tomli.loads(toml_str))
162
+
163
+ return overrides
flwr/common/constant.py CHANGED
@@ -57,6 +57,9 @@ APP_DIR = "apps"
57
57
  FAB_CONFIG_FILE = "pyproject.toml"
58
58
  FLWR_HOME = "FLWR_HOME"
59
59
 
60
+ # Constants entries in Node config for Simulation
61
+ PARTITION_ID_KEY = "partition-id"
62
+ NUM_PARTITIONS_KEY = "num-partitions"
60
63
 
61
64
  GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY = "flower-version"
62
65
  GRPC_ADAPTER_METADATA_SHOULD_EXIT_KEY = "should-exit"
flwr/common/context.py CHANGED
@@ -16,9 +16,9 @@
16
16
 
17
17
 
18
18
  from dataclasses import dataclass
19
- from typing import Optional
20
19
 
21
20
  from .record import RecordSet
21
+ from .typing import UserConfig
22
22
 
23
23
 
24
24
  @dataclass
@@ -27,6 +27,11 @@ class Context:
27
27
 
28
28
  Parameters
29
29
  ----------
30
+ node_id : int
31
+ The ID that identifies the node.
32
+ node_config : UserConfig
33
+ A config (key/value mapping) unique to the node and independent of the
34
+ `run_config`. This config persists across all runs this node participates in.
30
35
  state : RecordSet
31
36
  Holds records added by the entity in a given run and that will stay local.
32
37
  This means that the data it holds will never leave the system it's running from.
@@ -34,15 +39,25 @@ class Context:
34
39
  executing mods. It can also be used as a memory to access
35
40
  at different points during the lifecycle of this entity (e.g. across
36
41
  multiple rounds)
37
- partition_id : Optional[int] (default: None)
38
- An index that specifies the data partition that the ClientApp using this Context
39
- object should make use of. Setting this attribute is better suited for
40
- simulation or proto typing setups.
42
+ run_config : UserConfig
43
+ A config (key/value mapping) held by the entity in a given run and that will
44
+ stay local. It can be used at any point during the lifecycle of this entity
45
+ (e.g. across multiple rounds)
41
46
  """
42
47
 
48
+ node_id: int
49
+ node_config: UserConfig
43
50
  state: RecordSet
44
- partition_id: Optional[int]
45
-
46
- def __init__(self, state: RecordSet, partition_id: Optional[int] = None) -> None:
51
+ run_config: UserConfig
52
+
53
+ def __init__( # pylint: disable=too-many-arguments
54
+ self,
55
+ node_id: int,
56
+ node_config: UserConfig,
57
+ state: RecordSet,
58
+ run_config: UserConfig,
59
+ ) -> None:
60
+ self.node_id = node_id
61
+ self.node_config = node_config
47
62
  self.state = state
48
- self.partition_id = partition_id
63
+ self.run_config = run_config
flwr/common/logger.py CHANGED
@@ -197,6 +197,31 @@ def warn_deprecated_feature(name: str) -> None:
197
197
  )
198
198
 
199
199
 
200
+ def warn_deprecated_feature_with_example(
201
+ deprecation_message: str, example_message: str, code_example: str
202
+ ) -> None:
203
+ """Warn if a feature is deprecated and show code example."""
204
+ log(
205
+ WARN,
206
+ """DEPRECATED FEATURE: %s
207
+
208
+ Check the following `FEATURE UPDATE` warning message for the preferred
209
+ new mechanism to use this feature in Flower.
210
+ """,
211
+ deprecation_message,
212
+ )
213
+ log(
214
+ WARN,
215
+ """FEATURE UPDATE: %s
216
+ ------------------------------------------------------------
217
+ %s
218
+ ------------------------------------------------------------
219
+ """,
220
+ example_message,
221
+ code_example,
222
+ )
223
+
224
+
200
225
  def warn_unsupported_feature(name: str) -> None:
201
226
  """Warn the user when they use an unsupported feature."""
202
227
  log(