flwr-nightly 1.11.0.dev20240823__py3-none-any.whl → 1.11.1.dev20240912__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. (The original page linked to more details here; the link is not preserved in this text.)

Files changed (61)
  1. flwr/cli/app.py +0 -2
  2. flwr/cli/new/new.py +41 -40
  3. flwr/cli/new/templates/app/LICENSE.tpl +202 -0
  4. flwr/cli/new/templates/app/README.baseline.md.tpl +127 -0
  5. flwr/cli/new/templates/app/README.flowertune.md.tpl +16 -6
  6. flwr/cli/new/templates/app/code/__init__.baseline.py.tpl +1 -0
  7. flwr/cli/new/templates/app/code/client.baseline.py.tpl +58 -0
  8. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +19 -29
  9. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +0 -3
  10. flwr/cli/new/templates/app/code/dataset.baseline.py.tpl +36 -0
  11. flwr/cli/new/templates/app/code/flwr_tune/{client.py.tpl → client_app.py.tpl} +50 -40
  12. flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +32 -2
  13. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +0 -3
  14. flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +95 -0
  15. flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +83 -0
  16. flwr/cli/new/templates/app/code/model.baseline.py.tpl +80 -0
  17. flwr/cli/new/templates/app/code/server.baseline.py.tpl +46 -0
  18. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +18 -3
  19. flwr/cli/new/templates/app/code/strategy.baseline.py.tpl +1 -0
  20. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -13
  21. flwr/cli/new/templates/app/code/utils.baseline.py.tpl +1 -0
  22. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +138 -0
  23. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +34 -7
  24. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +9 -1
  25. flwr/cli/run/run.py +2 -2
  26. flwr/client/__init__.py +0 -4
  27. flwr/client/app.py +3 -4
  28. flwr/client/client_app.py +2 -2
  29. flwr/client/grpc_rere_client/client_interceptor.py +15 -7
  30. flwr/client/supernode/app.py +8 -7
  31. flwr/common/config.py +14 -11
  32. flwr/common/constant.py +12 -1
  33. flwr/common/record/recordset.py +1 -1
  34. flwr/common/record/typeddict.py +24 -1
  35. flwr/common/telemetry.py +36 -30
  36. flwr/server/__init__.py +0 -4
  37. flwr/server/app.py +21 -22
  38. flwr/server/compat/app.py +0 -5
  39. flwr/server/driver/grpc_driver.py +3 -6
  40. flwr/server/run_serverapp.py +20 -7
  41. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +19 -8
  42. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +13 -12
  43. flwr/server/superlink/fleet/vce/backend/raybackend.py +21 -12
  44. flwr/server/superlink/state/in_memory_state.py +15 -15
  45. flwr/server/superlink/state/sqlite_state.py +10 -10
  46. flwr/server/superlink/state/state.py +8 -8
  47. flwr/simulation/ray_transport/ray_actor.py +2 -2
  48. flwr/simulation/run_simulation.py +37 -8
  49. flwr/superexec/__init__.py +0 -6
  50. flwr/superexec/app.py +5 -3
  51. flwr/superexec/deployment.py +2 -2
  52. {flwr_nightly-1.11.0.dev20240823.dist-info → flwr_nightly-1.11.1.dev20240912.dist-info}/METADATA +3 -3
  53. {flwr_nightly-1.11.0.dev20240823.dist-info → flwr_nightly-1.11.1.dev20240912.dist-info}/RECORD +56 -48
  54. flwr_nightly-1.11.1.dev20240912.dist-info/entry_points.txt +10 -0
  55. flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +0 -89
  56. flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +0 -34
  57. flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +0 -48
  58. flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +0 -11
  59. flwr_nightly-1.11.0.dev20240823.dist-info/entry_points.txt +0 -10
  60. {flwr_nightly-1.11.0.dev20240823.dist-info → flwr_nightly-1.11.1.dev20240912.dist-info}/LICENSE +0 -0
  61. {flwr_nightly-1.11.0.dev20240823.dist-info → flwr_nightly-1.11.1.dev20240912.dist-info}/WHEEL +0 -0
@@ -4,24 +4,25 @@ import warnings
4
4
  from collections import OrderedDict
5
5
 
6
6
  import torch
7
+ import transformers
8
+ from datasets.utils.logging import disable_progress_bar
7
9
  from evaluate import load as load_metric
10
+ from flwr_datasets import FederatedDataset
11
+ from flwr_datasets.partitioner import IidPartitioner
8
12
  from torch.optim import AdamW
9
13
  from torch.utils.data import DataLoader
10
14
  from transformers import AutoTokenizer, DataCollatorWithPadding
11
15
 
12
- from flwr_datasets import FederatedDataset
13
- from flwr_datasets.partitioner import IidPartitioner
14
-
15
-
16
16
  warnings.filterwarnings("ignore", category=UserWarning)
17
- DEVICE = torch.device("cpu")
18
- CHECKPOINT = "distilbert-base-uncased" # transformer model checkpoint
17
+ warnings.filterwarnings("ignore", category=FutureWarning)
18
+ disable_progress_bar()
19
+ transformers.logging.set_verbosity_error()
19
20
 
20
21
 
21
22
  fds = None # Cache FederatedDataset
22
23
 
23
24
 
24
- def load_data(partition_id: int, num_partitions: int):
25
+ def load_data(partition_id: int, num_partitions: int, model_name: str):
25
26
  """Load IMDB data (training and eval)"""
26
27
  # Only initialize `FederatedDataset` once
27
28
  global fds
@@ -35,10 +36,12 @@ def load_data(partition_id: int, num_partitions: int):
35
36
  # Divide data: 80% train, 20% test
36
37
  partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
37
38
 
38
- tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
39
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
39
40
 
40
41
  def tokenize_function(examples):
41
- return tokenizer(examples["text"], truncation=True)
42
+ return tokenizer(
43
+ examples["text"], truncation=True, add_special_tokens=True, max_length=512
44
+ )
42
45
 
43
46
  partition_train_test = partition_train_test.map(tokenize_function, batched=True)
44
47
  partition_train_test = partition_train_test.remove_columns("text")
@@ -59,12 +62,12 @@ def load_data(partition_id: int, num_partitions: int):
59
62
  return trainloader, testloader
60
63
 
61
64
 
62
- def train(net, trainloader, epochs):
65
+ def train(net, trainloader, epochs, device):
63
66
  optimizer = AdamW(net.parameters(), lr=5e-5)
64
67
  net.train()
65
68
  for _ in range(epochs):
66
69
  for batch in trainloader:
67
- batch = {k: v.to(DEVICE) for k, v in batch.items()}
70
+ batch = {k: v.to(device) for k, v in batch.items()}
68
71
  outputs = net(**batch)
69
72
  loss = outputs.loss
70
73
  loss.backward()
@@ -72,12 +75,12 @@ def train(net, trainloader, epochs):
72
75
  optimizer.zero_grad()
73
76
 
74
77
 
75
- def test(net, testloader):
78
+ def test(net, testloader, device):
76
79
  metric = load_metric("accuracy")
77
80
  loss = 0
78
81
  net.eval()
79
82
  for batch in testloader:
80
- batch = {k: v.to(DEVICE) for k, v in batch.items()}
83
+ batch = {k: v.to(device) for k, v in batch.items()}
81
84
  with torch.no_grad():
82
85
  outputs = net(**batch)
83
86
  logits = outputs.logits
@@ -0,0 +1 @@
1
+ """$project_name: A Flower Baseline."""
@@ -0,0 +1,138 @@
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "$package_name"
7
+ version = "1.0.0"
8
+ description = ""
9
+ license = "Apache-2.0"
10
+ dependencies = [
11
+ "flwr[simulation]>=1.11.0",
12
+ "flwr-datasets[vision]>=0.3.0",
13
+ "torch==2.2.1",
14
+ "torchvision==0.17.1",
15
+ ]
16
+
17
+ [tool.hatch.metadata]
18
+ allow-direct-references = true
19
+
20
+ [project.optional-dependencies]
21
+ dev = [
22
+ "isort==5.13.2",
23
+ "black==24.2.0",
24
+ "docformatter==1.7.5",
25
+ "mypy==1.8.0",
26
+ "pylint==3.2.6",
27
+ "flake8==5.0.4",
28
+ "pytest==6.2.4",
29
+ "pytest-watch==4.2.0",
30
+ "ruff==0.1.9",
31
+ "types-requests==2.31.0.20240125",
32
+ ]
33
+
34
+ [tool.isort]
35
+ profile = "black"
36
+ known_first_party = ["flwr"]
37
+
38
+ [tool.black]
39
+ line-length = 88
40
+ target-version = ["py38", "py39", "py310", "py311"]
41
+
42
+ [tool.pytest.ini_options]
43
+ minversion = "6.2"
44
+ addopts = "-qq"
45
+ testpaths = [
46
+ "flwr_baselines",
47
+ ]
48
+
49
+ [tool.mypy]
50
+ ignore_missing_imports = true
51
+ strict = false
52
+ plugins = "numpy.typing.mypy_plugin"
53
+
54
+ [tool.pylint."MESSAGES CONTROL"]
55
+ disable = "duplicate-code,too-few-public-methods,useless-import-alias"
56
+ good-names = "i,j,k,_,x,y,X,Y,K,N"
57
+ max-args = 10
58
+ max-attributes = 15
59
+ max-locals = 36
60
+ max-branches = 20
61
+ max-statements = 55
62
+
63
+ [tool.pylint.typecheck]
64
+ generated-members = "numpy.*, torch.*, tensorflow.*"
65
+
66
+ [[tool.mypy.overrides]]
67
+ module = [
68
+ "importlib.metadata.*",
69
+ "importlib_metadata.*",
70
+ ]
71
+ follow_imports = "skip"
72
+ follow_imports_for_stubs = true
73
+ disallow_untyped_calls = false
74
+
75
+ [[tool.mypy.overrides]]
76
+ module = "torch.*"
77
+ follow_imports = "skip"
78
+ follow_imports_for_stubs = true
79
+
80
+ [tool.docformatter]
81
+ wrap-summaries = 88
82
+ wrap-descriptions = 88
83
+
84
+ [tool.ruff]
85
+ target-version = "py38"
86
+ line-length = 88
87
+ select = ["D", "E", "F", "W", "B", "ISC", "C4"]
88
+ fixable = ["D", "E", "F", "W", "B", "ISC", "C4"]
89
+ ignore = ["B024", "B027"]
90
+ exclude = [
91
+ ".bzr",
92
+ ".direnv",
93
+ ".eggs",
94
+ ".git",
95
+ ".hg",
96
+ ".mypy_cache",
97
+ ".nox",
98
+ ".pants.d",
99
+ ".pytype",
100
+ ".ruff_cache",
101
+ ".svn",
102
+ ".tox",
103
+ ".venv",
104
+ "__pypackages__",
105
+ "_build",
106
+ "buck-out",
107
+ "build",
108
+ "dist",
109
+ "node_modules",
110
+ "venv",
111
+ "proto",
112
+ ]
113
+
114
+ [tool.ruff.pydocstyle]
115
+ convention = "numpy"
116
+
117
+ [tool.hatch.build.targets.wheel]
118
+ packages = ["."]
119
+
120
+ [tool.flwr.app]
121
+ publisher = "$username"
122
+
123
+ [tool.flwr.app.components]
124
+ serverapp = "$import_name.server_app:app"
125
+ clientapp = "$import_name.client_app:app"
126
+
127
+ [tool.flwr.app.config]
128
+ num-server-rounds = 3
129
+ fraction-fit = 0.5
130
+ local-epochs = 1
131
+
132
+ [tool.flwr.federations]
133
+ default = "local-simulation"
134
+
135
+ [tool.flwr.federations.local-simulation]
136
+ options.num-supernodes = 10
137
+ options.backend.client-resources.num-cpus = 2
138
+ options.backend.client-resources.num-gpus = 0.0
@@ -8,15 +8,15 @@ version = "1.0.0"
8
8
  description = ""
9
9
  license = "Apache-2.0"
10
10
  dependencies = [
11
- "flwr[simulation]>=1.9.0,<2.0",
12
- "flwr-datasets>=0.1.0,<1.0.0",
13
- "hydra-core==1.3.2",
11
+ "flwr[simulation]>=1.10.0",
12
+ "flwr-datasets>=0.3.0",
14
13
  "trl==0.8.1",
15
14
  "bitsandbytes==0.43.0",
16
15
  "scipy==1.13.0",
17
16
  "peft==0.6.2",
18
17
  "transformers==4.39.3",
19
18
  "sentencepiece==0.2.0",
19
+ "omegaconf==2.3.0",
20
20
  ]
21
21
 
22
22
  [tool.hatch.build.targets.wheel]
@@ -26,14 +26,41 @@ packages = ["."]
26
26
  publisher = "$username"
27
27
 
28
28
  [tool.flwr.app.components]
29
- serverapp = "$import_name.app:server"
30
- clientapp = "$import_name.app:client"
29
+ serverapp = "$import_name.server_app:app"
30
+ clientapp = "$import_name.client_app:app"
31
31
 
32
32
  [tool.flwr.app.config]
33
- num-server-rounds = 3
33
+ model.name = "mistralai/Mistral-7B-v0.3"
34
+ model.quantization = 4
35
+ model.gradient-checkpointing = true
36
+ model.lora.peft-lora-r = 32
37
+ model.lora.peft-lora-alpha = 64
38
+ train.save-every-round = 5
39
+ train.learning-rate-max = 5e-5
40
+ train.learning-rate-min = 1e-6
41
+ train.seq-length = 512
42
+ train.training-arguments.output-dir = ""
43
+ train.training-arguments.learning-rate = ""
44
+ train.training-arguments.per-device-train-batch-size = 16
45
+ train.training-arguments.gradient-accumulation-steps = 1
46
+ train.training-arguments.logging-steps = 10
47
+ train.training-arguments.num-train-epochs = 3
48
+ train.training-arguments.max-steps = 10
49
+ train.training-arguments.save-steps = 1000
50
+ train.training-arguments.save-total-limit = 10
51
+ train.training-arguments.gradient-checkpointing = true
52
+ train.training-arguments.lr-scheduler-type = "constant"
53
+ strategy.fraction-fit = $fraction_fit
54
+ strategy.fraction-evaluate = 0.0
55
+ num-server-rounds = 200
56
+
57
+ [tool.flwr.app.config.static]
58
+ dataset.name = "$dataset_name"
34
59
 
35
60
  [tool.flwr.federations]
36
61
  default = "local-simulation"
37
62
 
38
63
  [tool.flwr.federations.local-simulation]
39
- options.num-supernodes = 10
64
+ options.num-supernodes = $num_clients
65
+ options.backend.client-resources.num-cpus = 6
66
+ options.backend.client-resources.num-gpus = 1.0
@@ -8,7 +8,7 @@ version = "1.0.0"
8
8
  description = ""
9
9
  license = "Apache-2.0"
10
10
  dependencies = [
11
- "flwr[simulation]>=1.10.0",
11
+ "flwr[simulation]>=1.11.0",
12
12
  "flwr-datasets>=0.3.0",
13
13
  "torch==2.2.1",
14
14
  "transformers>=4.30.0,<5.0",
@@ -29,10 +29,18 @@ clientapp = "$import_name.client_app:app"
29
29
 
30
30
  [tool.flwr.app.config]
31
31
  num-server-rounds = 3
32
+ fraction-fit = 0.5
32
33
  local-epochs = 1
34
+ model-name = "prajjwal1/bert-tiny" # Set a larger model if you have access to more GPU resources
35
+ num-labels = 2
33
36
 
34
37
  [tool.flwr.federations]
35
38
  default = "localhost"
36
39
 
37
40
  [tool.flwr.federations.localhost]
38
41
  options.num-supernodes = 10
42
+
43
+ [tool.flwr.federations.localhost-gpu]
44
+ options.num-supernodes = 10
45
+ options.backend.client-resources.num-cpus = 4 # each ClientApp assumes to use 4CPUs
46
+ options.backend.client-resources.num-gpus = 0.25 # at most 4 ClientApps will run in a given GPU
flwr/cli/run/run.py CHANGED
@@ -124,14 +124,14 @@ def run(
124
124
 
125
125
 
126
126
  def _run_with_superexec(
127
- app: Optional[Path],
127
+ app: Path,
128
128
  federation_config: Dict[str, Any],
129
129
  config_overrides: Optional[List[str]],
130
130
  ) -> None:
131
131
 
132
132
  insecure_str = federation_config.get("insecure")
133
133
  if root_certificates := federation_config.get("root-certificates"):
134
- root_certificates_bytes = Path(root_certificates).read_bytes()
134
+ root_certificates_bytes = (app / root_certificates).read_bytes()
135
135
  if insecure := bool(insecure_str):
136
136
  typer.secho(
137
137
  "❌ `root_certificates` were provided but the `insecure` parameter"
flwr/client/__init__.py CHANGED
@@ -20,8 +20,6 @@ from .app import start_numpy_client as start_numpy_client
20
20
  from .client import Client as Client
21
21
  from .client_app import ClientApp as ClientApp
22
22
  from .numpy_client import NumPyClient as NumPyClient
23
- from .supernode import run_client_app as run_client_app
24
- from .supernode import run_supernode as run_supernode
25
23
  from .typing import ClientFn as ClientFn
26
24
  from .typing import ClientFnExt as ClientFnExt
27
25
 
@@ -32,8 +30,6 @@ __all__ = [
32
30
  "ClientFnExt",
33
31
  "NumPyClient",
34
32
  "mod",
35
- "run_client_app",
36
- "run_supernode",
37
33
  "start_client",
38
34
  "start_numpy_client",
39
35
  ]
flwr/client/app.py CHANGED
@@ -35,6 +35,7 @@ from flwr.client.typing import ClientFnExt
35
35
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH, Context, EventType, Message, event
36
36
  from flwr.common.address import parse_address
37
37
  from flwr.common.constant import (
38
+ CLIENTAPPIO_API_DEFAULT_ADDRESS,
38
39
  MISSING_EXTRA_REST,
39
40
  RUN_ID_NUM_BYTES,
40
41
  TRANSPORT_TYPE_GRPC_ADAPTER,
@@ -60,8 +61,6 @@ from .message_handler.message_handler import handle_control_message
60
61
  from .node_state import NodeState
61
62
  from .numpy_client import NumPyClient
62
63
 
63
- ADDRESS_CLIENTAPPIO_API_GRPC_RERE = "0.0.0.0:9094"
64
-
65
64
  ISOLATION_MODE_SUBPROCESS = "subprocess"
66
65
  ISOLATION_MODE_PROCESS = "process"
67
66
 
@@ -211,7 +210,7 @@ def start_client_internal(
211
210
  max_wait_time: Optional[float] = None,
212
211
  flwr_path: Optional[Path] = None,
213
212
  isolation: Optional[str] = None,
214
- supernode_address: Optional[str] = ADDRESS_CLIENTAPPIO_API_GRPC_RERE,
213
+ supernode_address: Optional[str] = CLIENTAPPIO_API_DEFAULT_ADDRESS,
215
214
  ) -> None:
216
215
  """Start a Flower client node which connects to a Flower server.
217
216
 
@@ -266,7 +265,7 @@ def start_client_internal(
266
265
  by the SueprNode and communicates using gRPC at the address
267
266
  `supernode_address`. If `process`, the `ClientApp` runs in a separate isolated
268
267
  process and communicates using gRPC at the address `supernode_address`.
269
- supernode_address : Optional[str] (default: `ADDRESS_CLIENTAPPIO_API_GRPC_RERE`)
268
+ supernode_address : Optional[str] (default: `CLIENTAPPIO_API_DEFAULT_ADDRESS`)
270
269
  The SuperNode gRPC server address.
271
270
  """
272
271
  if insecure is None:
flwr/client/client_app.py CHANGED
@@ -41,11 +41,11 @@ def _alert_erroneous_client_fn() -> None:
41
41
 
42
42
  def _inspect_maybe_adapt_client_fn_signature(client_fn: ClientFnExt) -> ClientFnExt:
43
43
  client_fn_args = inspect.signature(client_fn).parameters
44
- first_arg = list(client_fn_args.keys())[0]
45
44
 
46
45
  if len(client_fn_args) != 1:
47
46
  _alert_erroneous_client_fn()
48
47
 
48
+ first_arg = list(client_fn_args.keys())[0]
49
49
  first_arg_type = client_fn_args[first_arg].annotation
50
50
 
51
51
  if first_arg_type is str or first_arg == "cid":
@@ -263,7 +263,7 @@ def _registration_error(fn_name: str) -> ValueError:
263
263
  >>> class FlowerClient(NumPyClient):
264
264
  >>> # ...
265
265
  >>>
266
- >>> def client_fn(cid) -> Client:
266
+ >>> def client_fn(context: Context):
267
267
  >>> return FlowerClient().to_client()
268
268
  >>>
269
269
  >>> app = ClientApp(
@@ -17,11 +17,13 @@
17
17
 
18
18
  import base64
19
19
  import collections
20
+ from logging import WARNING
20
21
  from typing import Any, Callable, Optional, Sequence, Tuple, Union
21
22
 
22
23
  import grpc
23
24
  from cryptography.hazmat.primitives.asymmetric import ec
24
25
 
26
+ from flwr.common.logger import log
25
27
  from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
26
28
  bytes_to_public_key,
27
29
  compute_hmac,
@@ -128,13 +130,12 @@ class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type:
128
130
  if self.shared_secret is None:
129
131
  raise RuntimeError("Failure to compute hmac")
130
132
 
133
+ message_bytes = request.SerializeToString(deterministic=True)
131
134
  metadata.append(
132
135
  (
133
136
  _AUTH_TOKEN_HEADER,
134
137
  base64.urlsafe_b64encode(
135
- compute_hmac(
136
- self.shared_secret, request.SerializeToString(True)
137
- )
138
+ compute_hmac(self.shared_secret, message_bytes)
138
139
  ),
139
140
  )
140
141
  )
@@ -151,8 +152,15 @@ class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type:
151
152
  server_public_key_bytes = base64.urlsafe_b64decode(
152
153
  _get_value_from_tuples(_PUBLIC_KEY_HEADER, response.initial_metadata())
153
154
  )
154
- self.server_public_key = bytes_to_public_key(server_public_key_bytes)
155
- self.shared_secret = generate_shared_key(
156
- self.private_key, self.server_public_key
157
- )
155
+
156
+ if server_public_key_bytes != b"":
157
+ self.server_public_key = bytes_to_public_key(server_public_key_bytes)
158
+ else:
159
+ log(WARNING, "Can't get server public key, SuperLink may be offline")
160
+
161
+ if self.server_public_key is not None:
162
+ self.shared_secret = generate_shared_key(
163
+ self.private_key, self.server_public_key
164
+ )
165
+
158
166
  return response
@@ -30,6 +30,7 @@ from cryptography.hazmat.primitives.serialization import (
30
30
  from flwr.common import EventType, event
31
31
  from flwr.common.config import parse_config_args
32
32
  from flwr.common.constant import (
33
+ FLEET_API_GRPC_RERE_DEFAULT_ADDRESS,
33
34
  TRANSPORT_TYPE_GRPC_ADAPTER,
34
35
  TRANSPORT_TYPE_GRPC_RERE,
35
36
  TRANSPORT_TYPE_REST,
@@ -44,8 +45,6 @@ from ..app import (
44
45
  )
45
46
  from ..clientapp.utils import get_load_client_app_fn
46
47
 
47
- ADDRESS_FLEET_API_GRPC_RERE = "0.0.0.0:9092"
48
-
49
48
 
50
49
  def run_supernode() -> None:
51
50
  """Run Flower SuperNode."""
@@ -77,7 +76,9 @@ def run_supernode() -> None:
77
76
  authentication_keys=authentication_keys,
78
77
  max_retries=args.max_retries,
79
78
  max_wait_time=args.max_wait_time,
80
- node_config=parse_config_args([args.node_config]),
79
+ node_config=parse_config_args(
80
+ [args.node_config] if args.node_config else args.node_config
81
+ ),
81
82
  isolation=args.isolation,
82
83
  supernode_address=args.supernode_address,
83
84
  )
@@ -101,11 +102,11 @@ def run_client_app() -> None:
101
102
 
102
103
  def _warn_deprecated_server_arg(args: argparse.Namespace) -> None:
103
104
  """Warn about the deprecated argument `--server`."""
104
- if args.server != ADDRESS_FLEET_API_GRPC_RERE:
105
+ if args.server != FLEET_API_GRPC_RERE_DEFAULT_ADDRESS:
105
106
  warn = "Passing flag --server is deprecated. Use --superlink instead."
106
107
  warn_deprecated_feature(warn)
107
108
 
108
- if args.superlink != ADDRESS_FLEET_API_GRPC_RERE:
109
+ if args.superlink != FLEET_API_GRPC_RERE_DEFAULT_ADDRESS:
109
110
  # if `--superlink` also passed, then
110
111
  # warn user that this argument overrides what was passed with `--server`
111
112
  log(
@@ -245,12 +246,12 @@ def _parse_args_common(parser: argparse.ArgumentParser) -> None:
245
246
  )
246
247
  parser.add_argument(
247
248
  "--server",
248
- default=ADDRESS_FLEET_API_GRPC_RERE,
249
+ default=FLEET_API_GRPC_RERE_DEFAULT_ADDRESS,
249
250
  help="Server address",
250
251
  )
251
252
  parser.add_argument(
252
253
  "--superlink",
253
- default=ADDRESS_FLEET_API_GRPC_RERE,
254
+ default=FLEET_API_GRPC_RERE_DEFAULT_ADDRESS,
254
255
  help="SuperLink Fleet API (gRPC-rere) address (IPv4, IPv6, or a domain name)",
255
256
  )
256
257
  parser.add_argument(
flwr/common/config.py CHANGED
@@ -185,23 +185,26 @@ def parse_config_args(
185
185
  if config is None:
186
186
  return overrides
187
187
 
188
+ # Handle if .toml file is passed
189
+ if len(config) == 1 and config[0].endswith(".toml"):
190
+ with Path(config[0]).open("rb") as config_file:
191
+ overrides = flatten_dict(tomli.load(config_file))
192
+ return overrides
193
+
188
194
  # Regular expression to capture key-value pairs with possible quoted values
189
195
  pattern = re.compile(r"(\S+?)=(\'[^\']*\'|\"[^\"]*\"|\S+)")
190
196
 
191
197
  for config_line in config:
192
198
  if config_line:
193
- matches = pattern.findall(config_line)
199
+ # .toml files aren't allowed alongside other configs
200
+ if config_line.endswith(".toml"):
201
+ raise ValueError(
202
+ "TOML files cannot be passed alongside key-value pairs."
203
+ )
194
204
 
195
- if (
196
- len(matches) == 1
197
- and "=" not in matches[0][0]
198
- and matches[0][0].endswith(".toml")
199
- ):
200
- with Path(matches[0][0]).open("rb") as config_file:
201
- overrides = flatten_dict(tomli.load(config_file))
202
- else:
203
- toml_str = "\n".join(f"{k} = {v}" for k, v in matches)
204
- overrides.update(tomli.loads(toml_str))
205
+ matches = pattern.findall(config_line)
206
+ toml_str = "\n".join(f"{k} = {v}" for k, v in matches)
207
+ overrides.update(tomli.loads(toml_str))
205
208
 
206
209
  return overrides
207
210
 
flwr/common/constant.py CHANGED
@@ -37,7 +37,18 @@ TRANSPORT_TYPES = [
37
37
  TRANSPORT_TYPE_VCE,
38
38
  ]
39
39
 
40
- SUPEREXEC_DEFAULT_ADDRESS = "0.0.0.0:9093"
40
+ # Addresses
41
+ # SuperNode
42
+ CLIENTAPPIO_API_DEFAULT_ADDRESS = "0.0.0.0:9094"
43
+ # SuperExec
44
+ EXEC_API_DEFAULT_ADDRESS = "0.0.0.0:9093"
45
+ # SuperLink
46
+ DRIVER_API_DEFAULT_ADDRESS = "0.0.0.0:9091"
47
+ FLEET_API_GRPC_RERE_DEFAULT_ADDRESS = "0.0.0.0:9092"
48
+ FLEET_API_GRPC_BIDI_DEFAULT_ADDRESS = (
49
+ "[::]:8080" # IPv6 to keep start_server compatible
50
+ )
51
+ FLEET_API_REST_DEFAULT_ADDRESS = "0.0.0.0:9093"
41
52
 
42
53
  # Constants for ping
43
54
  PING_DEFAULT_INTERVAL = 30
@@ -119,7 +119,7 @@ class RecordSet:
119
119
  Let's see an example.
120
120
 
121
121
  >>> from flwr.common import RecordSet
122
- >>> from flwr.common import ConfigsRecords, MetricsRecords, ParametersRecord
122
+ >>> from flwr.common import ConfigsRecord, MetricsRecord, ParametersRecord
123
123
  >>>
124
124
  >>> # Let's begin with an empty record
125
125
  >>> my_recordset = RecordSet()
@@ -15,7 +15,18 @@
15
15
  """Typed dict base class for *Records."""
16
16
 
17
17
 
18
- from typing import Callable, Dict, Generic, Iterator, MutableMapping, TypeVar, cast
18
+ from typing import (
19
+ Callable,
20
+ Dict,
21
+ Generic,
22
+ ItemsView,
23
+ Iterator,
24
+ KeysView,
25
+ MutableMapping,
26
+ TypeVar,
27
+ ValuesView,
28
+ cast,
29
+ )
19
30
 
20
31
  K = TypeVar("K") # Key type
21
32
  V = TypeVar("V") # Value type
@@ -73,3 +84,15 @@ class TypedDict(MutableMapping[K, V], Generic[K, V]):
73
84
  if isinstance(other, dict):
74
85
  return data == other
75
86
  return NotImplemented
87
+
88
+ def keys(self) -> KeysView[K]:
89
+ """D.keys() -> a set-like object providing a view on D's keys."""
90
+ return cast(Dict[K, V], self.__dict__["_data"]).keys()
91
+
92
+ def values(self) -> ValuesView[V]:
93
+ """D.values() -> an object providing a view on D's values."""
94
+ return cast(Dict[K, V], self.__dict__["_data"]).values()
95
+
96
+ def items(self) -> ItemsView[K, V]:
97
+ """D.items() -> a set-like object providing a view on D's items."""
98
+ return cast(Dict[K, V], self.__dict__["_data"]).items()