flwr-nightly 1.10.0.dev20240707__py3-none-any.whl → 1.10.0.dev20240722__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic; see the registry's notice for this release for more details.

Files changed (92)
  1. flwr/cli/build.py +16 -2
  2. flwr/cli/config_utils.py +36 -14
  3. flwr/cli/install.py +17 -1
  4. flwr/cli/new/new.py +31 -20
  5. flwr/cli/new/templates/app/code/client.hf.py.tpl +11 -3
  6. flwr/cli/new/templates/app/code/client.jax.py.tpl +2 -1
  7. flwr/cli/new/templates/app/code/client.mlx.py.tpl +15 -10
  8. flwr/cli/new/templates/app/code/client.numpy.py.tpl +2 -1
  9. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +12 -3
  10. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +6 -3
  11. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +13 -3
  12. flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +2 -2
  13. flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +1 -1
  14. flwr/cli/new/templates/app/code/server.hf.py.tpl +16 -11
  15. flwr/cli/new/templates/app/code/server.jax.py.tpl +15 -8
  16. flwr/cli/new/templates/app/code/server.mlx.py.tpl +11 -7
  17. flwr/cli/new/templates/app/code/server.numpy.py.tpl +15 -8
  18. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +15 -13
  19. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +16 -10
  20. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +16 -13
  21. flwr/cli/new/templates/app/code/task.hf.py.tpl +2 -2
  22. flwr/cli/new/templates/app/code/task.mlx.py.tpl +2 -2
  23. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -1
  24. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +9 -12
  25. flwr/cli/new/templates/app/pyproject.hf.toml.tpl +17 -16
  26. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +17 -11
  27. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +17 -12
  28. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +12 -12
  29. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +13 -12
  30. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +12 -12
  31. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +15 -12
  32. flwr/cli/run/run.py +128 -53
  33. flwr/client/app.py +56 -24
  34. flwr/client/client_app.py +28 -8
  35. flwr/client/grpc_adapter_client/connection.py +3 -2
  36. flwr/client/grpc_client/connection.py +3 -2
  37. flwr/client/grpc_rere_client/connection.py +17 -6
  38. flwr/client/message_handler/message_handler.py +1 -1
  39. flwr/client/node_state.py +59 -12
  40. flwr/client/node_state_tests.py +4 -3
  41. flwr/client/rest_client/connection.py +19 -8
  42. flwr/client/supernode/app.py +55 -24
  43. flwr/client/typing.py +2 -2
  44. flwr/common/config.py +87 -2
  45. flwr/common/constant.py +3 -0
  46. flwr/common/context.py +24 -9
  47. flwr/common/logger.py +25 -0
  48. flwr/common/serde.py +45 -0
  49. flwr/common/telemetry.py +17 -0
  50. flwr/common/typing.py +5 -0
  51. flwr/proto/common_pb2.py +36 -0
  52. flwr/proto/common_pb2.pyi +121 -0
  53. flwr/proto/common_pb2_grpc.py +4 -0
  54. flwr/proto/common_pb2_grpc.pyi +4 -0
  55. flwr/proto/driver_pb2.py +24 -19
  56. flwr/proto/driver_pb2.pyi +21 -1
  57. flwr/proto/exec_pb2.py +16 -11
  58. flwr/proto/exec_pb2.pyi +22 -1
  59. flwr/proto/run_pb2.py +12 -7
  60. flwr/proto/run_pb2.pyi +22 -1
  61. flwr/proto/task_pb2.py +7 -8
  62. flwr/server/__init__.py +2 -0
  63. flwr/server/compat/legacy_context.py +5 -4
  64. flwr/server/driver/grpc_driver.py +82 -140
  65. flwr/server/run_serverapp.py +40 -15
  66. flwr/server/server_app.py +56 -10
  67. flwr/server/serverapp_components.py +52 -0
  68. flwr/server/superlink/driver/driver_servicer.py +18 -3
  69. flwr/server/superlink/fleet/message_handler/message_handler.py +13 -2
  70. flwr/server/superlink/fleet/vce/backend/backend.py +4 -4
  71. flwr/server/superlink/fleet/vce/backend/raybackend.py +10 -10
  72. flwr/server/superlink/fleet/vce/vce_api.py +149 -117
  73. flwr/server/superlink/state/in_memory_state.py +11 -3
  74. flwr/server/superlink/state/sqlite_state.py +23 -8
  75. flwr/server/superlink/state/state.py +7 -2
  76. flwr/server/typing.py +2 -0
  77. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -2
  78. flwr/simulation/app.py +4 -3
  79. flwr/simulation/ray_transport/ray_actor.py +15 -19
  80. flwr/simulation/ray_transport/ray_client_proxy.py +22 -9
  81. flwr/simulation/run_simulation.py +237 -66
  82. flwr/superexec/app.py +14 -7
  83. flwr/superexec/deployment.py +110 -33
  84. flwr/superexec/exec_grpc.py +5 -1
  85. flwr/superexec/exec_servicer.py +4 -1
  86. flwr/superexec/executor.py +18 -0
  87. flwr/superexec/simulation.py +151 -0
  88. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/METADATA +3 -2
  89. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/RECORD +92 -86
  90. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/LICENSE +0 -0
  91. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/WHEEL +0 -0
  92. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/entry_points.txt +0 -0
flwr/client/node_state.py CHANGED
@@ -15,30 +15,72 @@
15
15
  """Node state."""
16
16
 
17
17
 
18
- from typing import Any, Dict, Optional
18
+ from dataclasses import dataclass
19
+ from pathlib import Path
20
+ from typing import Dict, Optional
19
21
 
20
22
  from flwr.common import Context, RecordSet
23
+ from flwr.common.config import get_fused_config, get_fused_config_from_dir
24
+ from flwr.common.typing import Run, UserConfig
25
+
26
+
27
+ @dataclass()
28
+ class RunInfo:
29
+ """Contains the Context and initial run_config of a Run."""
30
+
31
+ context: Context
32
+ initial_run_config: UserConfig
21
33
 
22
34
 
23
35
  class NodeState:
24
36
  """State of a node where client nodes execute runs."""
25
37
 
26
- def __init__(self, partition_id: Optional[int]) -> None:
27
- self._meta: Dict[str, Any] = {} # holds metadata about the node
28
- self.run_contexts: Dict[int, Context] = {}
29
- self._partition_id = partition_id
38
+ def __init__(
39
+ self,
40
+ node_id: int,
41
+ node_config: UserConfig,
42
+ ) -> None:
43
+ self.node_id = node_id
44
+ self.node_config = node_config
45
+ self.run_infos: Dict[int, RunInfo] = {}
30
46
 
31
- def register_context(self, run_id: int) -> None:
47
+ def register_context(
48
+ self,
49
+ run_id: int,
50
+ run: Optional[Run] = None,
51
+ flwr_path: Optional[Path] = None,
52
+ app_dir: Optional[str] = None,
53
+ ) -> None:
32
54
  """Register new run context for this node."""
33
- if run_id not in self.run_contexts:
34
- self.run_contexts[run_id] = Context(
35
- state=RecordSet(), partition_id=self._partition_id
55
+ if run_id not in self.run_infos:
56
+ initial_run_config = {}
57
+ if app_dir:
58
+ # Load from app directory
59
+ app_path = Path(app_dir)
60
+ if app_path.is_dir():
61
+ override_config = run.override_config if run else {}
62
+ initial_run_config = get_fused_config_from_dir(
63
+ app_path, override_config
64
+ )
65
+ else:
66
+ raise ValueError("The specified `app_dir` must be a directory.")
67
+ else:
68
+ # Load from .fab
69
+ initial_run_config = get_fused_config(run, flwr_path) if run else {}
70
+ self.run_infos[run_id] = RunInfo(
71
+ initial_run_config=initial_run_config,
72
+ context=Context(
73
+ node_id=self.node_id,
74
+ node_config=self.node_config,
75
+ state=RecordSet(),
76
+ run_config=initial_run_config.copy(),
77
+ ),
36
78
  )
37
79
 
38
80
  def retrieve_context(self, run_id: int) -> Context:
39
81
  """Get run context given a run_id."""
40
- if run_id in self.run_contexts:
41
- return self.run_contexts[run_id]
82
+ if run_id in self.run_infos:
83
+ return self.run_infos[run_id].context
42
84
 
43
85
  raise RuntimeError(
44
86
  f"Context for run_id={run_id} doesn't exist."
@@ -48,4 +90,9 @@ class NodeState:
48
90
 
49
91
  def update_context(self, run_id: int, context: Context) -> None:
50
92
  """Update run context."""
51
- self.run_contexts[run_id] = context
93
+ if context.run_config != self.run_infos[run_id].initial_run_config:
94
+ raise ValueError(
95
+ "The `run_config` field of the `Context` object cannot be "
96
+ f"modified (run_id: {run_id})."
97
+ )
98
+ self.run_infos[run_id].context = context
@@ -41,7 +41,7 @@ def test_multirun_in_node_state() -> None:
41
41
  expected_values = {0: "1", 1: "1" * 3, 2: "1" * 2, 3: "1", 5: "1"}
42
42
 
43
43
  # NodeState
44
- node_state = NodeState(partition_id=None)
44
+ node_state = NodeState(node_id=0, node_config={})
45
45
 
46
46
  for task in tasks:
47
47
  run_id = task.run_id
@@ -59,7 +59,8 @@ def test_multirun_in_node_state() -> None:
59
59
  node_state.update_context(run_id=run_id, context=updated_state)
60
60
 
61
61
  # Verify values
62
- for run_id, context in node_state.run_contexts.items():
62
+ for run_id, run_info in node_state.run_infos.items():
63
63
  assert (
64
- context.state.configs_records["counter"]["count"] == expected_values[run_id]
64
+ run_info.context.state.configs_records["counter"]["count"]
65
+ == expected_values[run_id]
65
66
  )
@@ -40,7 +40,12 @@ from flwr.common.constant import (
40
40
  from flwr.common.logger import log
41
41
  from flwr.common.message import Message, Metadata
42
42
  from flwr.common.retry_invoker import RetryInvoker
43
- from flwr.common.serde import message_from_taskins, message_to_taskres
43
+ from flwr.common.serde import (
44
+ message_from_taskins,
45
+ message_to_taskres,
46
+ user_config_from_proto,
47
+ )
48
+ from flwr.common.typing import Run
44
49
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
45
50
  CreateNodeRequest,
46
51
  CreateNodeResponse,
@@ -89,9 +94,9 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915
89
94
  Tuple[
90
95
  Callable[[], Optional[Message]],
91
96
  Callable[[Message], None],
97
+ Optional[Callable[[], Optional[int]]],
92
98
  Optional[Callable[[], None]],
93
- Optional[Callable[[], None]],
94
- Optional[Callable[[int], Tuple[str, str]]],
99
+ Optional[Callable[[int], Run]],
95
100
  ]
96
101
  ]:
97
102
  """Primitives for request/response-based interaction with a server.
@@ -236,19 +241,20 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915
236
241
  if not ping_stop_event.is_set():
237
242
  ping_stop_event.wait(next_interval)
238
243
 
239
- def create_node() -> None:
244
+ def create_node() -> Optional[int]:
240
245
  """Set create_node."""
241
246
  req = CreateNodeRequest(ping_interval=PING_DEFAULT_INTERVAL)
242
247
 
243
248
  # Send the request
244
249
  res = _request(req, CreateNodeResponse, PATH_CREATE_NODE)
245
250
  if res is None:
246
- return
251
+ return None
247
252
 
248
253
  # Remember the node and the ping-loop thread
249
254
  nonlocal node, ping_thread
250
255
  node = res.node
251
256
  ping_thread = start_ping_loop(ping, ping_stop_event)
257
+ return node.node_id
252
258
 
253
259
  def delete_node() -> None:
254
260
  """Set delete_node."""
@@ -344,16 +350,21 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915
344
350
  res.results, # pylint: disable=no-member
345
351
  )
346
352
 
347
- def get_run(run_id: int) -> Tuple[str, str]:
353
+ def get_run(run_id: int) -> Run:
348
354
  # Construct the request
349
355
  req = GetRunRequest(run_id=run_id)
350
356
 
351
357
  # Send the request
352
358
  res = _request(req, GetRunResponse, PATH_GET_RUN)
353
359
  if res is None:
354
- return "", ""
360
+ return Run(run_id, "", "", {})
355
361
 
356
- return res.run.fab_id, res.run.fab_version
362
+ return Run(
363
+ run_id,
364
+ res.run.fab_id,
365
+ res.run.fab_version,
366
+ user_config_from_proto(res.run.override_config),
367
+ )
357
368
 
358
369
  try:
359
370
  # Yield methods
@@ -29,7 +29,12 @@ from cryptography.hazmat.primitives.serialization import (
29
29
 
30
30
  from flwr.client.client_app import ClientApp, LoadClientAppError
31
31
  from flwr.common import EventType, event
32
- from flwr.common.config import get_flwr_dir, get_project_config, get_project_dir
32
+ from flwr.common.config import (
33
+ get_flwr_dir,
34
+ get_project_config,
35
+ get_project_dir,
36
+ parse_config_args,
37
+ )
33
38
  from flwr.common.constant import (
34
39
  TRANSPORT_TYPE_GRPC_ADAPTER,
35
40
  TRANSPORT_TYPE_GRPC_RERE,
@@ -55,7 +60,12 @@ def run_supernode() -> None:
55
60
  _warn_deprecated_server_arg(args)
56
61
 
57
62
  root_certificates = _get_certificates(args)
58
- load_fn = _get_load_client_app_fn(args, multi_app=True)
63
+ load_fn = _get_load_client_app_fn(
64
+ default_app_ref=getattr(args, "client-app"),
65
+ dir_arg=args.dir,
66
+ flwr_dir_arg=args.flwr_dir,
67
+ multi_app=True,
68
+ )
59
69
  authentication_keys = _try_setup_client_authentication(args)
60
70
 
61
71
  _start_client_internal(
@@ -67,7 +77,8 @@ def run_supernode() -> None:
67
77
  authentication_keys=authentication_keys,
68
78
  max_retries=args.max_retries,
69
79
  max_wait_time=args.max_wait_time,
70
- partition_id=args.partition_id,
80
+ node_config=parse_config_args([args.node_config]),
81
+ flwr_path=get_flwr_dir(args.flwr_dir),
71
82
  )
72
83
 
73
84
  # Graceful shutdown
@@ -87,11 +98,16 @@ def run_client_app() -> None:
87
98
  _warn_deprecated_server_arg(args)
88
99
 
89
100
  root_certificates = _get_certificates(args)
90
- load_fn = _get_load_client_app_fn(args, multi_app=False)
101
+ load_fn = _get_load_client_app_fn(
102
+ default_app_ref=getattr(args, "client-app"),
103
+ dir_arg=args.dir,
104
+ multi_app=False,
105
+ )
91
106
  authentication_keys = _try_setup_client_authentication(args)
92
107
 
93
108
  _start_client_internal(
94
109
  server_address=args.superlink,
110
+ node_config=parse_config_args([args.node_config]),
95
111
  load_client_app_fn=load_fn,
96
112
  transport=args.transport,
97
113
  root_certificates=root_certificates,
@@ -159,7 +175,10 @@ def _get_certificates(args: argparse.Namespace) -> Optional[bytes]:
159
175
 
160
176
 
161
177
  def _get_load_client_app_fn(
162
- args: argparse.Namespace, multi_app: bool
178
+ default_app_ref: str,
179
+ dir_arg: str,
180
+ multi_app: bool,
181
+ flwr_dir_arg: Optional[str] = None,
163
182
  ) -> Callable[[str, str], ClientApp]:
164
183
  """Get the load_client_app_fn function.
165
184
 
@@ -171,23 +190,27 @@ def _get_load_client_app_fn(
171
190
  loads a default ClientApp.
172
191
  """
173
192
  # Find the Flower directory containing Flower Apps (only for multi-app)
174
- flwr_dir = Path("")
175
- if "flwr_dir" in args:
176
- if args.flwr_dir is None:
193
+ if not multi_app:
194
+ flwr_dir = Path("")
195
+ else:
196
+ if flwr_dir_arg is None:
177
197
  flwr_dir = get_flwr_dir()
178
198
  else:
179
- flwr_dir = Path(args.flwr_dir).absolute()
199
+ flwr_dir = Path(flwr_dir_arg).absolute()
180
200
 
181
- sys.path.insert(0, str(flwr_dir.absolute()))
182
-
183
- default_app_ref: str = getattr(args, "client-app")
201
+ inserted_path = None
184
202
 
185
203
  if not multi_app:
186
204
  log(
187
205
  DEBUG,
188
206
  "Flower SuperNode will load and validate ClientApp `%s`",
189
- getattr(args, "client-app"),
207
+ default_app_ref,
190
208
  )
209
+ # Insert sys.path
210
+ dir_path = Path(dir_arg).absolute()
211
+ sys.path.insert(0, str(dir_path))
212
+ inserted_path = str(dir_path)
213
+
191
214
  valid, error_msg = validate(default_app_ref)
192
215
  if not valid and error_msg:
193
216
  raise LoadClientAppError(error_msg) from None
@@ -196,7 +219,7 @@ def _get_load_client_app_fn(
196
219
  # If multi-app feature is disabled
197
220
  if not multi_app:
198
221
  # Get sys path to be inserted
199
- sys_path = Path(args.dir).absolute()
222
+ dir_path = Path(dir_arg).absolute()
200
223
 
201
224
  # Set app reference
202
225
  client_app_ref = default_app_ref
@@ -209,7 +232,7 @@ def _get_load_client_app_fn(
209
232
 
210
233
  log(WARN, "FAB ID is not provided; the default ClientApp will be loaded.")
211
234
  # Get sys path to be inserted
212
- sys_path = Path(args.dir).absolute()
235
+ dir_path = Path(dir_arg).absolute()
213
236
 
214
237
  # Set app reference
215
238
  client_app_ref = default_app_ref
@@ -222,13 +245,21 @@ def _get_load_client_app_fn(
222
245
  raise LoadClientAppError("Failed to load ClientApp") from e
223
246
 
224
247
  # Get sys path to be inserted
225
- sys_path = Path(project_dir).absolute()
248
+ dir_path = Path(project_dir).absolute()
226
249
 
227
250
  # Set app reference
228
- client_app_ref = config["flower"]["components"]["clientapp"]
251
+ client_app_ref = config["tool"]["flwr"]["app"]["components"]["clientapp"]
229
252
 
230
253
  # Set sys.path
231
- sys.path.insert(0, str(sys_path))
254
+ nonlocal inserted_path
255
+ if inserted_path != str(dir_path):
256
+ # Remove the previously inserted path
257
+ if inserted_path is not None:
258
+ sys.path.remove(inserted_path)
259
+ # Insert the new path
260
+ sys.path.insert(0, str(dir_path))
261
+
262
+ inserted_path = str(dir_path)
232
263
 
233
264
  # Load ClientApp
234
265
  log(
@@ -236,7 +267,7 @@ def _get_load_client_app_fn(
236
267
  "Loading ClientApp `%s`",
237
268
  client_app_ref,
238
269
  )
239
- client_app = load_app(client_app_ref, LoadClientAppError, sys_path)
270
+ client_app = load_app(client_app_ref, LoadClientAppError, dir_path)
240
271
 
241
272
  if not isinstance(client_app, ClientApp):
242
273
  raise LoadClientAppError(
@@ -375,11 +406,11 @@ def _parse_args_common(parser: argparse.ArgumentParser) -> None:
375
406
  help="The SuperNode's public key (as a path str) to enable authentication.",
376
407
  )
377
408
  parser.add_argument(
378
- "--partition-id",
379
- type=int,
380
- help="The data partition index associated with this SuperNode. Better suited "
381
- "for prototyping purposes where a SuperNode might only load a fraction of an "
382
- "artificially partitioned dataset (e.g. using `flwr-datasets`)",
409
+ "--node-config",
410
+ type=str,
411
+ help="A comma separated list of key/value pairs (separated by `=`) to "
412
+ "configure the SuperNode. "
413
+ "E.g. --node-config 'key1=\"value1\",partition-id=0,num-partitions=100'",
383
414
  )
384
415
 
385
416
 
flwr/client/typing.py CHANGED
@@ -15,7 +15,7 @@
15
15
  """Custom types for Flower clients."""
16
16
 
17
17
 
18
- from typing import Callable, Optional
18
+ from typing import Callable
19
19
 
20
20
  from flwr.common import Context, Message
21
21
 
@@ -23,7 +23,7 @@ from .client import Client as Client
23
23
 
24
24
  # Compatibility
25
25
  ClientFn = Callable[[str], Client]
26
- ClientFnExt = Callable[[int, Optional[int]], Client]
26
+ ClientFnExt = Callable[[Context], Client]
27
27
 
28
28
  ClientAppCallable = Callable[[Message, Context], Message]
29
29
  Mod = Callable[[Message, Context, ClientAppCallable], Message]
flwr/common/config.py CHANGED
@@ -16,12 +16,13 @@
16
16
 
17
17
  import os
18
18
  from pathlib import Path
19
- from typing import Any, Dict, Optional, Union
19
+ from typing import Any, Dict, List, Optional, Tuple, Union, cast, get_args
20
20
 
21
21
  import tomli
22
22
 
23
23
  from flwr.cli.config_utils import validate_fields
24
24
  from flwr.common.constant import APP_DIR, FAB_CONFIG_FILE, FLWR_HOME
25
+ from flwr.common.typing import Run, UserConfig, UserConfigValue
25
26
 
26
27
 
27
28
  def get_flwr_dir(provided_path: Optional[str] = None) -> Path:
@@ -30,7 +31,7 @@ def get_flwr_dir(provided_path: Optional[str] = None) -> Path:
30
31
  return Path(
31
32
  os.getenv(
32
33
  FLWR_HOME,
33
- f"{os.getenv('XDG_DATA_HOME', os.getenv('HOME'))}/.flwr",
34
+ Path(f"{os.getenv('XDG_DATA_HOME', os.getenv('HOME'))}") / ".flwr",
34
35
  )
35
36
  )
36
37
  return Path(provided_path).absolute()
@@ -71,3 +72,87 @@ def get_project_config(project_dir: Union[str, Path]) -> Dict[str, Any]:
71
72
  )
72
73
 
73
74
  return config
75
+
76
+
77
+ def _fuse_dicts(
78
+ main_dict: UserConfig,
79
+ override_dict: UserConfig,
80
+ ) -> UserConfig:
81
+ fused_dict = main_dict.copy()
82
+
83
+ for key, value in override_dict.items():
84
+ if key in main_dict:
85
+ fused_dict[key] = value
86
+
87
+ return fused_dict
88
+
89
+
90
+ def get_fused_config_from_dir(
91
+ project_dir: Path, override_config: UserConfig
92
+ ) -> UserConfig:
93
+ """Merge the overrides from a given dict with the config from a Flower App."""
94
+ default_config = get_project_config(project_dir)["tool"]["flwr"]["app"].get(
95
+ "config", {}
96
+ )
97
+ flat_default_config = flatten_dict(default_config)
98
+
99
+ return _fuse_dicts(flat_default_config, override_config)
100
+
101
+
102
+ def get_fused_config(run: Run, flwr_dir: Optional[Path]) -> UserConfig:
103
+ """Merge the overrides from a `Run` with the config from a FAB.
104
+
105
+ Get the config using the fab_id and the fab_version, remove the nesting by adding
106
+ the nested keys as prefixes separated by dots, and fuse it with the override dict.
107
+ """
108
+ if not run.fab_id or not run.fab_version:
109
+ return {}
110
+
111
+ project_dir = get_project_dir(run.fab_id, run.fab_version, flwr_dir)
112
+
113
+ return get_fused_config_from_dir(project_dir, run.override_config)
114
+
115
+
116
+ def flatten_dict(raw_dict: Dict[str, Any], parent_key: str = "") -> UserConfig:
117
+ """Flatten dict by joining nested keys with a given separator."""
118
+ items: List[Tuple[str, UserConfigValue]] = []
119
+ separator: str = "."
120
+ for k, v in raw_dict.items():
121
+ new_key = f"{parent_key}{separator}{k}" if parent_key else k
122
+ if isinstance(v, dict):
123
+ items.extend(flatten_dict(v, parent_key=new_key).items())
124
+ elif isinstance(v, get_args(UserConfigValue)):
125
+ items.append((new_key, cast(UserConfigValue, v)))
126
+ else:
127
+ raise ValueError(
128
+ f"The value for key {k} needs to be of type `int`, `float`, "
129
+ "`bool, `str`, or a `dict` of those.",
130
+ )
131
+ return dict(items)
132
+
133
+
134
+ def parse_config_args(
135
+ config: Optional[List[str]],
136
+ separator: str = ",",
137
+ ) -> UserConfig:
138
+ """Parse separator separated list of key-value pairs separated by '='."""
139
+ overrides: UserConfig = {}
140
+
141
+ if config is None:
142
+ return overrides
143
+
144
+ for config_line in config:
145
+ if config_line:
146
+ overrides_list = config_line.split(separator)
147
+ if (
148
+ len(overrides_list) == 1
149
+ and "=" not in overrides_list
150
+ and overrides_list[0].endswith(".toml")
151
+ ):
152
+ with Path(overrides_list[0]).open("rb") as config_file:
153
+ overrides = flatten_dict(tomli.load(config_file))
154
+ else:
155
+ toml_str = "\n".join(overrides_list)
156
+ overrides.update(tomli.loads(toml_str))
157
+
158
+ return overrides
flwr/common/constant.py CHANGED
@@ -57,6 +57,9 @@ APP_DIR = "apps"
57
57
  FAB_CONFIG_FILE = "pyproject.toml"
58
58
  FLWR_HOME = "FLWR_HOME"
59
59
 
60
+ # Constants entries in Node config for Simulation
61
+ PARTITION_ID_KEY = "partition-id"
62
+ NUM_PARTITIONS_KEY = "num-partitions"
60
63
 
61
64
  GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY = "flower-version"
62
65
  GRPC_ADAPTER_METADATA_SHOULD_EXIT_KEY = "should-exit"
flwr/common/context.py CHANGED
@@ -16,9 +16,9 @@
16
16
 
17
17
 
18
18
  from dataclasses import dataclass
19
- from typing import Optional
20
19
 
21
20
  from .record import RecordSet
21
+ from .typing import UserConfig
22
22
 
23
23
 
24
24
  @dataclass
@@ -27,6 +27,11 @@ class Context:
27
27
 
28
28
  Parameters
29
29
  ----------
30
+ node_id : int
31
+ The ID that identifies the node.
32
+ node_config : UserConfig
33
+ A config (key/value mapping) unique to the node and independent of the
34
+ `run_config`. This config persists across all runs this node participates in.
30
35
  state : RecordSet
31
36
  Holds records added by the entity in a given run and that will stay local.
32
37
  This means that the data it holds will never leave the system it's running from.
@@ -34,15 +39,25 @@ class Context:
34
39
  executing mods. It can also be used as a memory to access
35
40
  at different points during the lifecycle of this entity (e.g. across
36
41
  multiple rounds)
37
- partition_id : Optional[int] (default: None)
38
- An index that specifies the data partition that the ClientApp using this Context
39
- object should make use of. Setting this attribute is better suited for
40
- simulation or proto typing setups.
42
+ run_config : UserConfig
43
+ A config (key/value mapping) held by the entity in a given run and that will
44
+ stay local. It can be used at any point during the lifecycle of this entity
45
+ (e.g. across multiple rounds)
41
46
  """
42
47
 
48
+ node_id: int
49
+ node_config: UserConfig
43
50
  state: RecordSet
44
- partition_id: Optional[int]
45
-
46
- def __init__(self, state: RecordSet, partition_id: Optional[int] = None) -> None:
51
+ run_config: UserConfig
52
+
53
+ def __init__( # pylint: disable=too-many-arguments
54
+ self,
55
+ node_id: int,
56
+ node_config: UserConfig,
57
+ state: RecordSet,
58
+ run_config: UserConfig,
59
+ ) -> None:
60
+ self.node_id = node_id
61
+ self.node_config = node_config
47
62
  self.state = state
48
- self.partition_id = partition_id
63
+ self.run_config = run_config
flwr/common/logger.py CHANGED
@@ -197,6 +197,31 @@ def warn_deprecated_feature(name: str) -> None:
197
197
  )
198
198
 
199
199
 
200
+ def warn_deprecated_feature_with_example(
201
+ deprecation_message: str, example_message: str, code_example: str
202
+ ) -> None:
203
+ """Warn if a feature is deprecated and show code example."""
204
+ log(
205
+ WARN,
206
+ """DEPRECATED FEATURE: %s
207
+
208
+ Check the following `FEATURE UPDATE` warning message for the preferred
209
+ new mechanism to use this feature in Flower.
210
+ """,
211
+ deprecation_message,
212
+ )
213
+ log(
214
+ WARN,
215
+ """FEATURE UPDATE: %s
216
+ ------------------------------------------------------------
217
+ %s
218
+ ------------------------------------------------------------
219
+ """,
220
+ example_message,
221
+ code_example,
222
+ )
223
+
224
+
200
225
  def warn_unsupported_feature(name: str) -> None:
201
226
  """Warn the user when they use an unsupported feature."""
202
227
  log(
flwr/common/serde.py CHANGED
@@ -671,3 +671,48 @@ def message_from_taskres(taskres: TaskRes) -> Message:
671
671
  )
672
672
  message.metadata.created_at = taskres.task.created_at
673
673
  return message
674
+
675
+
676
+ # === User configs ===
677
+
678
+
679
+ def user_config_to_proto(user_config: typing.UserConfig) -> Any:
680
+ """Serialize `UserConfig` to ProtoBuf."""
681
+ proto = {}
682
+ for key, value in user_config.items():
683
+ proto[key] = user_config_value_to_proto(value)
684
+ return proto
685
+
686
+
687
+ def user_config_from_proto(proto: Any) -> typing.UserConfig:
688
+ """Deserialize `UserConfig` from ProtoBuf."""
689
+ metrics = {}
690
+ for key, value in proto.items():
691
+ metrics[key] = user_config_value_from_proto(value)
692
+ return metrics
693
+
694
+
695
+ def user_config_value_to_proto(user_config_value: typing.UserConfigValue) -> Scalar:
696
+ """Serialize `UserConfigValue` to ProtoBuf."""
697
+ if isinstance(user_config_value, bool):
698
+ return Scalar(bool=user_config_value)
699
+
700
+ if isinstance(user_config_value, float):
701
+ return Scalar(double=user_config_value)
702
+
703
+ if isinstance(user_config_value, int):
704
+ return Scalar(sint64=user_config_value)
705
+
706
+ if isinstance(user_config_value, str):
707
+ return Scalar(string=user_config_value)
708
+
709
+ raise ValueError(
710
+ f"Accepted types: {bool, float, int, str} (but not {type(user_config_value)})"
711
+ )
712
+
713
+
714
+ def user_config_value_from_proto(scalar_msg: Scalar) -> typing.UserConfigValue:
715
+ """Deserialize `UserConfigValue` from ProtoBuf."""
716
+ scalar_field = scalar_msg.WhichOneof("scalar")
717
+ scalar = getattr(scalar_msg, cast(str, scalar_field))
718
+ return cast(typing.UserConfigValue, scalar)
flwr/common/telemetry.py CHANGED
@@ -64,6 +64,18 @@ def _get_home() -> Path:
64
64
  return Path().home()
65
65
 
66
66
 
67
+ def _get_partner_id() -> str:
68
+ """Get partner ID."""
69
+ partner_id = os.getenv("FLWR_TELEMETRY_PARTNER_ID")
70
+ if not partner_id:
71
+ return "unavailable"
72
+ try:
73
+ uuid.UUID(partner_id)
74
+ except ValueError:
75
+ partner_id = "invalid"
76
+ return partner_id
77
+
78
+
67
79
  def _get_source_id() -> str:
68
80
  """Get existing or new source ID."""
69
81
  source_id = "unavailable"
@@ -177,6 +189,7 @@ state: Dict[str, Union[Optional[str], Optional[ThreadPoolExecutor]]] = {
177
189
  "executor": None,
178
190
  "source": None,
179
191
  "cluster": None,
192
+ "partner": None,
180
193
  }
181
194
 
182
195
 
@@ -202,11 +215,15 @@ def create_event(event_type: EventType, event_details: Optional[Dict[str, Any]])
202
215
  if state["cluster"] is None:
203
216
  state["cluster"] = str(uuid.uuid4())
204
217
 
218
+ if state["partner"] is None:
219
+ state["partner"] = _get_partner_id()
220
+
205
221
  if event_details is None:
206
222
  event_details = {}
207
223
 
208
224
  date = datetime.datetime.now(tz=datetime.timezone.utc).isoformat()
209
225
  context = {
226
+ "partner": state["partner"],
210
227
  "source": state["source"],
211
228
  "cluster": state["cluster"],
212
229
  "date": date,