flwr-nightly 1.10.0.dev20240624__py3-none-any.whl → 1.10.0.dev20240722__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic (see the registry's page for this release for more details).

Files changed (95)
  1. flwr/cli/build.py +18 -4
  2. flwr/cli/config_utils.py +36 -14
  3. flwr/cli/install.py +17 -1
  4. flwr/cli/new/new.py +31 -20
  5. flwr/cli/new/templates/app/code/client.hf.py.tpl +11 -3
  6. flwr/cli/new/templates/app/code/client.jax.py.tpl +2 -1
  7. flwr/cli/new/templates/app/code/client.mlx.py.tpl +15 -10
  8. flwr/cli/new/templates/app/code/client.numpy.py.tpl +2 -1
  9. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +12 -3
  10. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +6 -3
  11. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +13 -3
  12. flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +2 -2
  13. flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +1 -1
  14. flwr/cli/new/templates/app/code/server.hf.py.tpl +16 -11
  15. flwr/cli/new/templates/app/code/server.jax.py.tpl +15 -8
  16. flwr/cli/new/templates/app/code/server.mlx.py.tpl +11 -7
  17. flwr/cli/new/templates/app/code/server.numpy.py.tpl +15 -8
  18. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +15 -13
  19. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +16 -10
  20. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +16 -13
  21. flwr/cli/new/templates/app/code/task.hf.py.tpl +2 -2
  22. flwr/cli/new/templates/app/code/task.mlx.py.tpl +2 -2
  23. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -1
  24. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +9 -12
  25. flwr/cli/new/templates/app/pyproject.hf.toml.tpl +17 -16
  26. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +17 -11
  27. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +17 -12
  28. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +12 -12
  29. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +13 -12
  30. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +12 -12
  31. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +15 -12
  32. flwr/cli/run/run.py +135 -51
  33. flwr/client/__init__.py +2 -0
  34. flwr/client/app.py +63 -26
  35. flwr/client/client_app.py +49 -4
  36. flwr/client/grpc_adapter_client/connection.py +3 -2
  37. flwr/client/grpc_client/connection.py +3 -2
  38. flwr/client/grpc_rere_client/connection.py +17 -6
  39. flwr/client/message_handler/message_handler.py +3 -4
  40. flwr/client/node_state.py +60 -10
  41. flwr/client/node_state_tests.py +4 -3
  42. flwr/client/rest_client/connection.py +19 -8
  43. flwr/client/supernode/app.py +60 -21
  44. flwr/client/typing.py +1 -0
  45. flwr/common/config.py +87 -2
  46. flwr/common/constant.py +6 -0
  47. flwr/common/context.py +26 -1
  48. flwr/common/logger.py +38 -0
  49. flwr/common/message.py +0 -17
  50. flwr/common/serde.py +45 -0
  51. flwr/common/telemetry.py +17 -0
  52. flwr/common/typing.py +5 -0
  53. flwr/proto/common_pb2.py +36 -0
  54. flwr/proto/common_pb2.pyi +121 -0
  55. flwr/proto/common_pb2_grpc.py +4 -0
  56. flwr/proto/common_pb2_grpc.pyi +4 -0
  57. flwr/proto/driver_pb2.py +24 -19
  58. flwr/proto/driver_pb2.pyi +21 -1
  59. flwr/proto/exec_pb2.py +16 -11
  60. flwr/proto/exec_pb2.pyi +22 -1
  61. flwr/proto/run_pb2.py +12 -7
  62. flwr/proto/run_pb2.pyi +22 -1
  63. flwr/proto/task_pb2.py +7 -8
  64. flwr/server/__init__.py +2 -0
  65. flwr/server/compat/legacy_context.py +5 -4
  66. flwr/server/driver/grpc_driver.py +82 -140
  67. flwr/server/run_serverapp.py +40 -15
  68. flwr/server/server_app.py +56 -10
  69. flwr/server/serverapp_components.py +52 -0
  70. flwr/server/superlink/driver/driver_servicer.py +18 -3
  71. flwr/server/superlink/fleet/message_handler/message_handler.py +13 -2
  72. flwr/server/superlink/fleet/vce/backend/backend.py +4 -4
  73. flwr/server/superlink/fleet/vce/backend/raybackend.py +10 -10
  74. flwr/server/superlink/fleet/vce/vce_api.py +149 -122
  75. flwr/server/superlink/state/in_memory_state.py +15 -7
  76. flwr/server/superlink/state/sqlite_state.py +27 -12
  77. flwr/server/superlink/state/state.py +7 -2
  78. flwr/server/superlink/state/utils.py +6 -0
  79. flwr/server/typing.py +2 -0
  80. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -2
  81. flwr/simulation/app.py +52 -36
  82. flwr/simulation/ray_transport/ray_actor.py +15 -19
  83. flwr/simulation/ray_transport/ray_client_proxy.py +33 -13
  84. flwr/simulation/run_simulation.py +237 -66
  85. flwr/superexec/app.py +14 -7
  86. flwr/superexec/deployment.py +186 -0
  87. flwr/superexec/exec_grpc.py +5 -1
  88. flwr/superexec/exec_servicer.py +4 -1
  89. flwr/superexec/executor.py +18 -0
  90. flwr/superexec/simulation.py +151 -0
  91. {flwr_nightly-1.10.0.dev20240624.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/METADATA +3 -2
  92. {flwr_nightly-1.10.0.dev20240624.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/RECORD +95 -88
  93. {flwr_nightly-1.10.0.dev20240624.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/LICENSE +0 -0
  94. {flwr_nightly-1.10.0.dev20240624.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/WHEEL +0 -0
  95. {flwr_nightly-1.10.0.dev20240624.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/entry_points.txt +0 -0
flwr/client/message_handler/message_handler.py CHANGED
@@ -14,7 +14,6 @@
14
14
  # ==============================================================================
15
15
  """Client-side message handler."""
16
16
 
17
-
18
17
  from logging import WARN
19
18
  from typing import Optional, Tuple, cast
20
19
 
@@ -25,7 +24,7 @@ from flwr.client.client import (
25
24
  maybe_call_get_properties,
26
25
  )
27
26
  from flwr.client.numpy_client import NumPyClient
28
- from flwr.client.typing import ClientFn
27
+ from flwr.client.typing import ClientFnExt
29
28
  from flwr.common import ConfigsRecord, Context, Message, Metadata, RecordSet, log
30
29
  from flwr.common.constant import MessageType, MessageTypeLegacy
31
30
  from flwr.common.recordset_compat import (
@@ -90,10 +89,10 @@ def handle_control_message(message: Message) -> Tuple[Optional[Message], int]:
90
89
 
91
90
 
92
91
  def handle_legacy_message_from_msgtype(
93
- client_fn: ClientFn, message: Message, context: Context
92
+ client_fn: ClientFnExt, message: Message, context: Context
94
93
  ) -> Message:
95
94
  """Handle legacy message in the inner most mod."""
96
- client = client_fn(str(message.metadata.partition_id))
95
+ client = client_fn(context)
97
96
 
98
97
  # Check if NumPyClient is returend
99
98
  if isinstance(client, NumPyClient):
flwr/client/node_state.py CHANGED
@@ -15,27 +15,72 @@
15
15
  """Node state."""
16
16
 
17
17
 
18
- from typing import Any, Dict
18
+ from dataclasses import dataclass
19
+ from pathlib import Path
20
+ from typing import Dict, Optional
19
21
 
20
22
  from flwr.common import Context, RecordSet
23
+ from flwr.common.config import get_fused_config, get_fused_config_from_dir
24
+ from flwr.common.typing import Run, UserConfig
25
+
26
+
27
+ @dataclass()
28
+ class RunInfo:
29
+ """Contains the Context and initial run_config of a Run."""
30
+
31
+ context: Context
32
+ initial_run_config: UserConfig
21
33
 
22
34
 
23
35
  class NodeState:
24
36
  """State of a node where client nodes execute runs."""
25
37
 
26
- def __init__(self) -> None:
27
- self._meta: Dict[str, Any] = {} # holds metadata about the node
28
- self.run_contexts: Dict[int, Context] = {}
38
+ def __init__(
39
+ self,
40
+ node_id: int,
41
+ node_config: UserConfig,
42
+ ) -> None:
43
+ self.node_id = node_id
44
+ self.node_config = node_config
45
+ self.run_infos: Dict[int, RunInfo] = {}
29
46
 
30
- def register_context(self, run_id: int) -> None:
47
+ def register_context(
48
+ self,
49
+ run_id: int,
50
+ run: Optional[Run] = None,
51
+ flwr_path: Optional[Path] = None,
52
+ app_dir: Optional[str] = None,
53
+ ) -> None:
31
54
  """Register new run context for this node."""
32
- if run_id not in self.run_contexts:
33
- self.run_contexts[run_id] = Context(state=RecordSet())
55
+ if run_id not in self.run_infos:
56
+ initial_run_config = {}
57
+ if app_dir:
58
+ # Load from app directory
59
+ app_path = Path(app_dir)
60
+ if app_path.is_dir():
61
+ override_config = run.override_config if run else {}
62
+ initial_run_config = get_fused_config_from_dir(
63
+ app_path, override_config
64
+ )
65
+ else:
66
+ raise ValueError("The specified `app_dir` must be a directory.")
67
+ else:
68
+ # Load from .fab
69
+ initial_run_config = get_fused_config(run, flwr_path) if run else {}
70
+ self.run_infos[run_id] = RunInfo(
71
+ initial_run_config=initial_run_config,
72
+ context=Context(
73
+ node_id=self.node_id,
74
+ node_config=self.node_config,
75
+ state=RecordSet(),
76
+ run_config=initial_run_config.copy(),
77
+ ),
78
+ )
34
79
 
35
80
  def retrieve_context(self, run_id: int) -> Context:
36
81
  """Get run context given a run_id."""
37
- if run_id in self.run_contexts:
38
- return self.run_contexts[run_id]
82
+ if run_id in self.run_infos:
83
+ return self.run_infos[run_id].context
39
84
 
40
85
  raise RuntimeError(
41
86
  f"Context for run_id={run_id} doesn't exist."
@@ -45,4 +90,9 @@ class NodeState:
45
90
 
46
91
  def update_context(self, run_id: int, context: Context) -> None:
47
92
  """Update run context."""
48
- self.run_contexts[run_id] = context
93
+ if context.run_config != self.run_infos[run_id].initial_run_config:
94
+ raise ValueError(
95
+ "The `run_config` field of the `Context` object cannot be "
96
+ f"modified (run_id: {run_id})."
97
+ )
98
+ self.run_infos[run_id].context = context
flwr/client/node_state_tests.py CHANGED
@@ -41,7 +41,7 @@ def test_multirun_in_node_state() -> None:
41
41
  expected_values = {0: "1", 1: "1" * 3, 2: "1" * 2, 3: "1", 5: "1"}
42
42
 
43
43
  # NodeState
44
- node_state = NodeState()
44
+ node_state = NodeState(node_id=0, node_config={})
45
45
 
46
46
  for task in tasks:
47
47
  run_id = task.run_id
@@ -59,7 +59,8 @@ def test_multirun_in_node_state() -> None:
59
59
  node_state.update_context(run_id=run_id, context=updated_state)
60
60
 
61
61
  # Verify values
62
- for run_id, context in node_state.run_contexts.items():
62
+ for run_id, run_info in node_state.run_infos.items():
63
63
  assert (
64
- context.state.configs_records["counter"]["count"] == expected_values[run_id]
64
+ run_info.context.state.configs_records["counter"]["count"]
65
+ == expected_values[run_id]
65
66
  )
flwr/client/rest_client/connection.py CHANGED
@@ -40,7 +40,12 @@ from flwr.common.constant import (
40
40
  from flwr.common.logger import log
41
41
  from flwr.common.message import Message, Metadata
42
42
  from flwr.common.retry_invoker import RetryInvoker
43
- from flwr.common.serde import message_from_taskins, message_to_taskres
43
+ from flwr.common.serde import (
44
+ message_from_taskins,
45
+ message_to_taskres,
46
+ user_config_from_proto,
47
+ )
48
+ from flwr.common.typing import Run
44
49
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
45
50
  CreateNodeRequest,
46
51
  CreateNodeResponse,
@@ -89,9 +94,9 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915
89
94
  Tuple[
90
95
  Callable[[], Optional[Message]],
91
96
  Callable[[Message], None],
97
+ Optional[Callable[[], Optional[int]]],
92
98
  Optional[Callable[[], None]],
93
- Optional[Callable[[], None]],
94
- Optional[Callable[[int], Tuple[str, str]]],
99
+ Optional[Callable[[int], Run]],
95
100
  ]
96
101
  ]:
97
102
  """Primitives for request/response-based interaction with a server.
@@ -236,19 +241,20 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915
236
241
  if not ping_stop_event.is_set():
237
242
  ping_stop_event.wait(next_interval)
238
243
 
239
- def create_node() -> None:
244
+ def create_node() -> Optional[int]:
240
245
  """Set create_node."""
241
246
  req = CreateNodeRequest(ping_interval=PING_DEFAULT_INTERVAL)
242
247
 
243
248
  # Send the request
244
249
  res = _request(req, CreateNodeResponse, PATH_CREATE_NODE)
245
250
  if res is None:
246
- return
251
+ return None
247
252
 
248
253
  # Remember the node and the ping-loop thread
249
254
  nonlocal node, ping_thread
250
255
  node = res.node
251
256
  ping_thread = start_ping_loop(ping, ping_stop_event)
257
+ return node.node_id
252
258
 
253
259
  def delete_node() -> None:
254
260
  """Set delete_node."""
@@ -344,16 +350,21 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915
344
350
  res.results, # pylint: disable=no-member
345
351
  )
346
352
 
347
- def get_run(run_id: int) -> Tuple[str, str]:
353
+ def get_run(run_id: int) -> Run:
348
354
  # Construct the request
349
355
  req = GetRunRequest(run_id=run_id)
350
356
 
351
357
  # Send the request
352
358
  res = _request(req, GetRunResponse, PATH_GET_RUN)
353
359
  if res is None:
354
- return "", ""
360
+ return Run(run_id, "", "", {})
355
361
 
356
- return res.run.fab_id, res.run.fab_version
362
+ return Run(
363
+ run_id,
364
+ res.run.fab_id,
365
+ res.run.fab_version,
366
+ user_config_from_proto(res.run.override_config),
367
+ )
357
368
 
358
369
  try:
359
370
  # Yield methods
flwr/client/supernode/app.py CHANGED
@@ -29,7 +29,12 @@ from cryptography.hazmat.primitives.serialization import (
29
29
 
30
30
  from flwr.client.client_app import ClientApp, LoadClientAppError
31
31
  from flwr.common import EventType, event
32
- from flwr.common.config import get_flwr_dir, get_project_config, get_project_dir
32
+ from flwr.common.config import (
33
+ get_flwr_dir,
34
+ get_project_config,
35
+ get_project_dir,
36
+ parse_config_args,
37
+ )
33
38
  from flwr.common.constant import (
34
39
  TRANSPORT_TYPE_GRPC_ADAPTER,
35
40
  TRANSPORT_TYPE_GRPC_RERE,
@@ -55,7 +60,12 @@ def run_supernode() -> None:
55
60
  _warn_deprecated_server_arg(args)
56
61
 
57
62
  root_certificates = _get_certificates(args)
58
- load_fn = _get_load_client_app_fn(args, multi_app=True)
63
+ load_fn = _get_load_client_app_fn(
64
+ default_app_ref=getattr(args, "client-app"),
65
+ dir_arg=args.dir,
66
+ flwr_dir_arg=args.flwr_dir,
67
+ multi_app=True,
68
+ )
59
69
  authentication_keys = _try_setup_client_authentication(args)
60
70
 
61
71
  _start_client_internal(
@@ -67,6 +77,8 @@ def run_supernode() -> None:
67
77
  authentication_keys=authentication_keys,
68
78
  max_retries=args.max_retries,
69
79
  max_wait_time=args.max_wait_time,
80
+ node_config=parse_config_args([args.node_config]),
81
+ flwr_path=get_flwr_dir(args.flwr_dir),
70
82
  )
71
83
 
72
84
  # Graceful shutdown
@@ -86,11 +98,16 @@ def run_client_app() -> None:
86
98
  _warn_deprecated_server_arg(args)
87
99
 
88
100
  root_certificates = _get_certificates(args)
89
- load_fn = _get_load_client_app_fn(args, multi_app=False)
101
+ load_fn = _get_load_client_app_fn(
102
+ default_app_ref=getattr(args, "client-app"),
103
+ dir_arg=args.dir,
104
+ multi_app=False,
105
+ )
90
106
  authentication_keys = _try_setup_client_authentication(args)
91
107
 
92
108
  _start_client_internal(
93
109
  server_address=args.superlink,
110
+ node_config=parse_config_args([args.node_config]),
94
111
  load_client_app_fn=load_fn,
95
112
  transport=args.transport,
96
113
  root_certificates=root_certificates,
@@ -158,7 +175,10 @@ def _get_certificates(args: argparse.Namespace) -> Optional[bytes]:
158
175
 
159
176
 
160
177
  def _get_load_client_app_fn(
161
- args: argparse.Namespace, multi_app: bool
178
+ default_app_ref: str,
179
+ dir_arg: str,
180
+ multi_app: bool,
181
+ flwr_dir_arg: Optional[str] = None,
162
182
  ) -> Callable[[str, str], ClientApp]:
163
183
  """Get the load_client_app_fn function.
164
184
 
@@ -170,23 +190,27 @@ def _get_load_client_app_fn(
170
190
  loads a default ClientApp.
171
191
  """
172
192
  # Find the Flower directory containing Flower Apps (only for multi-app)
173
- flwr_dir = Path("")
174
- if "flwr_dir" in args:
175
- if args.flwr_dir is None:
193
+ if not multi_app:
194
+ flwr_dir = Path("")
195
+ else:
196
+ if flwr_dir_arg is None:
176
197
  flwr_dir = get_flwr_dir()
177
198
  else:
178
- flwr_dir = Path(args.flwr_dir).absolute()
179
-
180
- sys.path.insert(0, str(flwr_dir.absolute()))
199
+ flwr_dir = Path(flwr_dir_arg).absolute()
181
200
 
182
- default_app_ref: str = getattr(args, "client-app")
201
+ inserted_path = None
183
202
 
184
203
  if not multi_app:
185
204
  log(
186
205
  DEBUG,
187
206
  "Flower SuperNode will load and validate ClientApp `%s`",
188
- getattr(args, "client-app"),
207
+ default_app_ref,
189
208
  )
209
+ # Insert sys.path
210
+ dir_path = Path(dir_arg).absolute()
211
+ sys.path.insert(0, str(dir_path))
212
+ inserted_path = str(dir_path)
213
+
190
214
  valid, error_msg = validate(default_app_ref)
191
215
  if not valid and error_msg:
192
216
  raise LoadClientAppError(error_msg) from None
@@ -195,7 +219,7 @@ def _get_load_client_app_fn(
195
219
  # If multi-app feature is disabled
196
220
  if not multi_app:
197
221
  # Get sys path to be inserted
198
- sys_path = Path(args.dir).absolute()
222
+ dir_path = Path(dir_arg).absolute()
199
223
 
200
224
  # Set app reference
201
225
  client_app_ref = default_app_ref
@@ -208,7 +232,7 @@ def _get_load_client_app_fn(
208
232
 
209
233
  log(WARN, "FAB ID is not provided; the default ClientApp will be loaded.")
210
234
  # Get sys path to be inserted
211
- sys_path = Path(args.dir).absolute()
235
+ dir_path = Path(dir_arg).absolute()
212
236
 
213
237
  # Set app reference
214
238
  client_app_ref = default_app_ref
@@ -221,13 +245,21 @@ def _get_load_client_app_fn(
221
245
  raise LoadClientAppError("Failed to load ClientApp") from e
222
246
 
223
247
  # Get sys path to be inserted
224
- sys_path = Path(project_dir).absolute()
248
+ dir_path = Path(project_dir).absolute()
225
249
 
226
250
  # Set app reference
227
- client_app_ref = config["flower"]["components"]["clientapp"]
251
+ client_app_ref = config["tool"]["flwr"]["app"]["components"]["clientapp"]
228
252
 
229
253
  # Set sys.path
230
- sys.path.insert(0, str(sys_path))
254
+ nonlocal inserted_path
255
+ if inserted_path != str(dir_path):
256
+ # Remove the previously inserted path
257
+ if inserted_path is not None:
258
+ sys.path.remove(inserted_path)
259
+ # Insert the new path
260
+ sys.path.insert(0, str(dir_path))
261
+
262
+ inserted_path = str(dir_path)
231
263
 
232
264
  # Load ClientApp
233
265
  log(
@@ -235,7 +267,7 @@ def _get_load_client_app_fn(
235
267
  "Loading ClientApp `%s`",
236
268
  client_app_ref,
237
269
  )
238
- client_app = load_app(client_app_ref, LoadClientAppError, sys_path)
270
+ client_app = load_app(client_app_ref, LoadClientAppError, dir_path)
239
271
 
240
272
  if not isinstance(client_app, ClientApp):
241
273
  raise LoadClientAppError(
@@ -344,8 +376,8 @@ def _parse_args_common(parser: argparse.ArgumentParser) -> None:
344
376
  "--max-retries",
345
377
  type=int,
346
378
  default=None,
347
- help="The maximum number of times the client will try to connect to the"
348
- "server before giving up in case of a connection error. By default,"
379
+ help="The maximum number of times the client will try to reconnect to the"
380
+ "SuperLink before giving up in case of a connection error. By default,"
349
381
  "it is set to None, meaning there is no limit to the number of tries.",
350
382
  )
351
383
  parser.add_argument(
@@ -353,7 +385,7 @@ def _parse_args_common(parser: argparse.ArgumentParser) -> None:
353
385
  type=float,
354
386
  default=None,
355
387
  help="The maximum duration before the client stops trying to"
356
- "connect to the server in case of connection error. By default, it"
388
+ "connect to the SuperLink in case of connection error. By default, it"
357
389
  "is set to None, meaning there is no limit to the total time.",
358
390
  )
359
391
  parser.add_argument(
@@ -373,6 +405,13 @@ def _parse_args_common(parser: argparse.ArgumentParser) -> None:
373
405
  type=str,
374
406
  help="The SuperNode's public key (as a path str) to enable authentication.",
375
407
  )
408
+ parser.add_argument(
409
+ "--node-config",
410
+ type=str,
411
+ help="A comma separated list of key/value pairs (separated by `=`) to "
412
+ "configure the SuperNode. "
413
+ "E.g. --node-config 'key1=\"value1\",partition-id=0,num-partitions=100'",
414
+ )
376
415
 
377
416
 
378
417
  def _try_setup_client_authentication(
flwr/client/typing.py CHANGED
@@ -23,6 +23,7 @@ from .client import Client as Client
23
23
 
24
24
  # Compatibility
25
25
  ClientFn = Callable[[str], Client]
26
+ ClientFnExt = Callable[[Context], Client]
26
27
 
27
28
  ClientAppCallable = Callable[[Message, Context], Message]
28
29
  Mod = Callable[[Message, Context, ClientAppCallable], Message]
flwr/common/config.py CHANGED
@@ -16,12 +16,13 @@
16
16
 
17
17
  import os
18
18
  from pathlib import Path
19
- from typing import Any, Dict, Optional, Union
19
+ from typing import Any, Dict, List, Optional, Tuple, Union, cast, get_args
20
20
 
21
21
  import tomli
22
22
 
23
23
  from flwr.cli.config_utils import validate_fields
24
24
  from flwr.common.constant import APP_DIR, FAB_CONFIG_FILE, FLWR_HOME
25
+ from flwr.common.typing import Run, UserConfig, UserConfigValue
25
26
 
26
27
 
27
28
  def get_flwr_dir(provided_path: Optional[str] = None) -> Path:
@@ -30,7 +31,7 @@ def get_flwr_dir(provided_path: Optional[str] = None) -> Path:
30
31
  return Path(
31
32
  os.getenv(
32
33
  FLWR_HOME,
33
- f"{os.getenv('XDG_DATA_HOME', os.getenv('HOME'))}/.flwr",
34
+ Path(f"{os.getenv('XDG_DATA_HOME', os.getenv('HOME'))}") / ".flwr",
34
35
  )
35
36
  )
36
37
  return Path(provided_path).absolute()
@@ -71,3 +72,87 @@ def get_project_config(project_dir: Union[str, Path]) -> Dict[str, Any]:
71
72
  )
72
73
 
73
74
  return config
75
+
76
+
77
+ def _fuse_dicts(
78
+ main_dict: UserConfig,
79
+ override_dict: UserConfig,
80
+ ) -> UserConfig:
81
+ fused_dict = main_dict.copy()
82
+
83
+ for key, value in override_dict.items():
84
+ if key in main_dict:
85
+ fused_dict[key] = value
86
+
87
+ return fused_dict
88
+
89
+
90
+ def get_fused_config_from_dir(
91
+ project_dir: Path, override_config: UserConfig
92
+ ) -> UserConfig:
93
+ """Merge the overrides from a given dict with the config from a Flower App."""
94
+ default_config = get_project_config(project_dir)["tool"]["flwr"]["app"].get(
95
+ "config", {}
96
+ )
97
+ flat_default_config = flatten_dict(default_config)
98
+
99
+ return _fuse_dicts(flat_default_config, override_config)
100
+
101
+
102
+ def get_fused_config(run: Run, flwr_dir: Optional[Path]) -> UserConfig:
103
+ """Merge the overrides from a `Run` with the config from a FAB.
104
+
105
+ Get the config using the fab_id and the fab_version, remove the nesting by adding
106
+ the nested keys as prefixes separated by dots, and fuse it with the override dict.
107
+ """
108
+ if not run.fab_id or not run.fab_version:
109
+ return {}
110
+
111
+ project_dir = get_project_dir(run.fab_id, run.fab_version, flwr_dir)
112
+
113
+ return get_fused_config_from_dir(project_dir, run.override_config)
114
+
115
+
116
+ def flatten_dict(raw_dict: Dict[str, Any], parent_key: str = "") -> UserConfig:
117
+ """Flatten dict by joining nested keys with a given separator."""
118
+ items: List[Tuple[str, UserConfigValue]] = []
119
+ separator: str = "."
120
+ for k, v in raw_dict.items():
121
+ new_key = f"{parent_key}{separator}{k}" if parent_key else k
122
+ if isinstance(v, dict):
123
+ items.extend(flatten_dict(v, parent_key=new_key).items())
124
+ elif isinstance(v, get_args(UserConfigValue)):
125
+ items.append((new_key, cast(UserConfigValue, v)))
126
+ else:
127
+ raise ValueError(
128
+ f"The value for key {k} needs to be of type `int`, `float`, "
129
+ "`bool, `str`, or a `dict` of those.",
130
+ )
131
+ return dict(items)
132
+
133
+
134
+ def parse_config_args(
135
+ config: Optional[List[str]],
136
+ separator: str = ",",
137
+ ) -> UserConfig:
138
+ """Parse separator separated list of key-value pairs separated by '='."""
139
+ overrides: UserConfig = {}
140
+
141
+ if config is None:
142
+ return overrides
143
+
144
+ for config_line in config:
145
+ if config_line:
146
+ overrides_list = config_line.split(separator)
147
+ if (
148
+ len(overrides_list) == 1
149
+ and "=" not in overrides_list
150
+ and overrides_list[0].endswith(".toml")
151
+ ):
152
+ with Path(overrides_list[0]).open("rb") as config_file:
153
+ overrides = flatten_dict(tomli.load(config_file))
154
+ else:
155
+ toml_str = "\n".join(overrides_list)
156
+ overrides.update(tomli.loads(toml_str))
157
+
158
+ return overrides
flwr/common/constant.py CHANGED
@@ -46,6 +46,9 @@ PING_BASE_MULTIPLIER = 0.8
46
46
  PING_RANDOM_RANGE = (-0.1, 0.1)
47
47
  PING_MAX_INTERVAL = 1e300
48
48
 
49
+ # IDs
50
+ RUN_ID_NUM_BYTES = 8
51
+ NODE_ID_NUM_BYTES = 8
49
52
  GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY = "flower-version"
50
53
  GRPC_ADAPTER_METADATA_SHOULD_EXIT_KEY = "should-exit"
51
54
 
@@ -54,6 +57,9 @@ APP_DIR = "apps"
54
57
  FAB_CONFIG_FILE = "pyproject.toml"
55
58
  FLWR_HOME = "FLWR_HOME"
56
59
 
60
+ # Constants entries in Node config for Simulation
61
+ PARTITION_ID_KEY = "partition-id"
62
+ NUM_PARTITIONS_KEY = "num-partitions"
57
63
 
58
64
  GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY = "flower-version"
59
65
  GRPC_ADAPTER_METADATA_SHOULD_EXIT_KEY = "should-exit"
flwr/common/context.py CHANGED
@@ -18,14 +18,20 @@
18
18
  from dataclasses import dataclass
19
19
 
20
20
  from .record import RecordSet
21
+ from .typing import UserConfig
21
22
 
22
23
 
23
24
  @dataclass
24
25
  class Context:
25
- """State of your run.
26
+ """Context of your run.
26
27
 
27
28
  Parameters
28
29
  ----------
30
+ node_id : int
31
+ The ID that identifies the node.
32
+ node_config : UserConfig
33
+ A config (key/value mapping) unique to the node and independent of the
34
+ `run_config`. This config persists across all runs this node participates in.
29
35
  state : RecordSet
30
36
  Holds records added by the entity in a given run and that will stay local.
31
37
  This means that the data it holds will never leave the system it's running from.
@@ -33,6 +39,25 @@ class Context:
33
39
  executing mods. It can also be used as a memory to access
34
40
  at different points during the lifecycle of this entity (e.g. across
35
41
  multiple rounds)
42
+ run_config : UserConfig
43
+ A config (key/value mapping) held by the entity in a given run and that will
44
+ stay local. It can be used at any point during the lifecycle of this entity
45
+ (e.g. across multiple rounds)
36
46
  """
37
47
 
48
+ node_id: int
49
+ node_config: UserConfig
38
50
  state: RecordSet
51
+ run_config: UserConfig
52
+
53
+ def __init__( # pylint: disable=too-many-arguments
54
+ self,
55
+ node_id: int,
56
+ node_config: UserConfig,
57
+ state: RecordSet,
58
+ run_config: UserConfig,
59
+ ) -> None:
60
+ self.node_id = node_id
61
+ self.node_config = node_config
62
+ self.state = state
63
+ self.run_config = run_config
flwr/common/logger.py CHANGED
@@ -197,6 +197,44 @@ def warn_deprecated_feature(name: str) -> None:
197
197
  )
198
198
 
199
199
 
200
+ def warn_deprecated_feature_with_example(
201
+ deprecation_message: str, example_message: str, code_example: str
202
+ ) -> None:
203
+ """Warn if a feature is deprecated and show code example."""
204
+ log(
205
+ WARN,
206
+ """DEPRECATED FEATURE: %s
207
+
208
+ Check the following `FEATURE UPDATE` warning message for the preferred
209
+ new mechanism to use this feature in Flower.
210
+ """,
211
+ deprecation_message,
212
+ )
213
+ log(
214
+ WARN,
215
+ """FEATURE UPDATE: %s
216
+ ------------------------------------------------------------
217
+ %s
218
+ ------------------------------------------------------------
219
+ """,
220
+ example_message,
221
+ code_example,
222
+ )
223
+
224
+
225
+ def warn_unsupported_feature(name: str) -> None:
226
+ """Warn the user when they use an unsupported feature."""
227
+ log(
228
+ WARN,
229
+ """UNSUPPORTED FEATURE: %s
230
+
231
+ This is an unsupported feature. It will be removed
232
+ entirely in future versions of Flower.
233
+ """,
234
+ name,
235
+ )
236
+
237
+
200
238
  def set_logger_propagation(
201
239
  child_logger: logging.Logger, value: bool = True
202
240
  ) -> logging.Logger:
flwr/common/message.py CHANGED
@@ -48,10 +48,6 @@ class Metadata: # pylint: disable=too-many-instance-attributes
48
48
  message_type : str
49
49
  A string that encodes the action to be executed on
50
50
  the receiving end.
51
- partition_id : Optional[int]
52
- An identifier that can be used when loading a particular
53
- data partition for a ClientApp. Making use of this identifier
54
- is more relevant when conducting simulations.
55
51
  """
56
52
 
57
53
  def __init__( # pylint: disable=too-many-arguments
@@ -64,7 +60,6 @@ class Metadata: # pylint: disable=too-many-instance-attributes
64
60
  group_id: str,
65
61
  ttl: float,
66
62
  message_type: str,
67
- partition_id: int | None = None,
68
63
  ) -> None:
69
64
  var_dict = {
70
65
  "_run_id": run_id,
@@ -75,7 +70,6 @@ class Metadata: # pylint: disable=too-many-instance-attributes
75
70
  "_group_id": group_id,
76
71
  "_ttl": ttl,
77
72
  "_message_type": message_type,
78
- "_partition_id": partition_id,
79
73
  }
80
74
  self.__dict__.update(var_dict)
81
75
 
@@ -149,16 +143,6 @@ class Metadata: # pylint: disable=too-many-instance-attributes
149
143
  """Set message_type."""
150
144
  self.__dict__["_message_type"] = value
151
145
 
152
- @property
153
- def partition_id(self) -> int | None:
154
- """An identifier telling which data partition a ClientApp should use."""
155
- return cast(int, self.__dict__["_partition_id"])
156
-
157
- @partition_id.setter
158
- def partition_id(self, value: int) -> None:
159
- """Set partition_id."""
160
- self.__dict__["_partition_id"] = value
161
-
162
146
  def __repr__(self) -> str:
163
147
  """Return a string representation of this instance."""
164
148
  view = ", ".join([f"{k.lstrip('_')}={v!r}" for k, v in self.__dict__.items()])
@@ -398,5 +382,4 @@ def _create_reply_metadata(msg: Message, ttl: float) -> Metadata:
398
382
  group_id=msg.metadata.group_id,
399
383
  ttl=ttl,
400
384
  message_type=msg.metadata.message_type,
401
- partition_id=msg.metadata.partition_id,
402
385
  )