flwr-nightly 1.12.0.dev20240907__py3-none-any.whl → 1.12.0.dev20240913__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

Files changed (118)
  1. flwr/cli/build.py +1 -2
  2. flwr/cli/config_utils.py +10 -10
  3. flwr/cli/install.py +1 -2
  4. flwr/cli/new/new.py +26 -40
  5. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +19 -29
  6. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +0 -3
  7. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +18 -3
  8. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -13
  9. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +9 -1
  10. flwr/cli/run/run.py +6 -7
  11. flwr/cli/utils.py +2 -2
  12. flwr/client/app.py +14 -14
  13. flwr/client/client_app.py +5 -5
  14. flwr/client/clientapp/app.py +2 -2
  15. flwr/client/dpfedavg_numpy_client.py +6 -7
  16. flwr/client/grpc_adapter_client/connection.py +4 -3
  17. flwr/client/grpc_client/connection.py +4 -3
  18. flwr/client/grpc_rere_client/client_interceptor.py +5 -5
  19. flwr/client/grpc_rere_client/connection.py +5 -4
  20. flwr/client/grpc_rere_client/grpc_adapter.py +2 -2
  21. flwr/client/message_handler/message_handler.py +3 -3
  22. flwr/client/mod/secure_aggregation/secaggplus_mod.py +25 -25
  23. flwr/client/mod/utils.py +1 -3
  24. flwr/client/node_state.py +2 -2
  25. flwr/client/numpy_client.py +8 -8
  26. flwr/client/rest_client/connection.py +5 -4
  27. flwr/client/supernode/app.py +7 -8
  28. flwr/common/address.py +2 -2
  29. flwr/common/config.py +8 -8
  30. flwr/common/constant.py +12 -1
  31. flwr/common/differential_privacy.py +2 -2
  32. flwr/common/dp.py +1 -3
  33. flwr/common/exit_handlers.py +3 -3
  34. flwr/common/grpc.py +2 -1
  35. flwr/common/logger.py +3 -3
  36. flwr/common/object_ref.py +3 -3
  37. flwr/common/record/configsrecord.py +3 -3
  38. flwr/common/record/metricsrecord.py +3 -3
  39. flwr/common/record/parametersrecord.py +3 -2
  40. flwr/common/record/recordset.py +1 -1
  41. flwr/common/record/typeddict.py +23 -10
  42. flwr/common/recordset_compat.py +7 -5
  43. flwr/common/retry_invoker.py +6 -17
  44. flwr/common/secure_aggregation/crypto/shamir.py +10 -10
  45. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +2 -2
  46. flwr/common/secure_aggregation/ndarrays_arithmetic.py +16 -16
  47. flwr/common/secure_aggregation/quantization.py +7 -7
  48. flwr/common/secure_aggregation/secaggplus_utils.py +3 -5
  49. flwr/common/serde.py +11 -9
  50. flwr/common/telemetry.py +5 -5
  51. flwr/common/typing.py +19 -19
  52. flwr/common/version.py +2 -3
  53. flwr/server/app.py +18 -18
  54. flwr/server/client_manager.py +6 -6
  55. flwr/server/compat/app_utils.py +2 -3
  56. flwr/server/driver/driver.py +3 -2
  57. flwr/server/driver/grpc_driver.py +7 -7
  58. flwr/server/driver/inmemory_driver.py +5 -4
  59. flwr/server/history.py +8 -9
  60. flwr/server/run_serverapp.py +5 -6
  61. flwr/server/server.py +36 -36
  62. flwr/server/strategy/aggregate.py +13 -13
  63. flwr/server/strategy/bulyan.py +8 -8
  64. flwr/server/strategy/dp_adaptive_clipping.py +20 -20
  65. flwr/server/strategy/dp_fixed_clipping.py +19 -19
  66. flwr/server/strategy/dpfedavg_adaptive.py +6 -6
  67. flwr/server/strategy/dpfedavg_fixed.py +10 -10
  68. flwr/server/strategy/fault_tolerant_fedavg.py +11 -11
  69. flwr/server/strategy/fedadagrad.py +8 -8
  70. flwr/server/strategy/fedadam.py +8 -8
  71. flwr/server/strategy/fedavg.py +16 -16
  72. flwr/server/strategy/fedavg_android.py +16 -16
  73. flwr/server/strategy/fedavgm.py +8 -8
  74. flwr/server/strategy/fedmedian.py +4 -4
  75. flwr/server/strategy/fedopt.py +5 -5
  76. flwr/server/strategy/fedprox.py +6 -6
  77. flwr/server/strategy/fedtrimmedavg.py +8 -8
  78. flwr/server/strategy/fedxgb_bagging.py +11 -11
  79. flwr/server/strategy/fedxgb_cyclic.py +9 -9
  80. flwr/server/strategy/fedxgb_nn_avg.py +5 -5
  81. flwr/server/strategy/fedyogi.py +8 -8
  82. flwr/server/strategy/krum.py +8 -8
  83. flwr/server/strategy/qfedavg.py +15 -15
  84. flwr/server/strategy/strategy.py +10 -10
  85. flwr/server/superlink/driver/driver_grpc.py +2 -2
  86. flwr/server/superlink/driver/driver_servicer.py +6 -6
  87. flwr/server/superlink/ffs/disk_ffs.py +4 -4
  88. flwr/server/superlink/ffs/ffs.py +4 -4
  89. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +2 -2
  90. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +2 -1
  91. flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +2 -1
  92. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +9 -8
  93. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +5 -3
  94. flwr/server/superlink/fleet/message_handler/message_handler.py +2 -2
  95. flwr/server/superlink/fleet/rest_rere/rest_api.py +2 -1
  96. flwr/server/superlink/fleet/vce/backend/__init__.py +2 -3
  97. flwr/server/superlink/fleet/vce/backend/backend.py +3 -3
  98. flwr/server/superlink/fleet/vce/backend/raybackend.py +26 -17
  99. flwr/server/superlink/fleet/vce/vce_api.py +6 -6
  100. flwr/server/superlink/state/in_memory_state.py +18 -18
  101. flwr/server/superlink/state/sqlite_state.py +22 -21
  102. flwr/server/superlink/state/state.py +7 -7
  103. flwr/server/utils/tensorboard.py +4 -4
  104. flwr/server/utils/validator.py +2 -2
  105. flwr/server/workflow/default_workflows.py +5 -5
  106. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +22 -22
  107. flwr/simulation/app.py +8 -8
  108. flwr/simulation/ray_transport/ray_actor.py +23 -23
  109. flwr/simulation/run_simulation.py +16 -4
  110. flwr/superexec/app.py +4 -4
  111. flwr/superexec/deployment.py +2 -2
  112. flwr/superexec/exec_grpc.py +2 -2
  113. flwr/superexec/exec_servicer.py +3 -2
  114. {flwr_nightly-1.12.0.dev20240907.dist-info → flwr_nightly-1.12.0.dev20240913.dist-info}/METADATA +4 -6
  115. {flwr_nightly-1.12.0.dev20240907.dist-info → flwr_nightly-1.12.0.dev20240913.dist-info}/RECORD +118 -118
  116. {flwr_nightly-1.12.0.dev20240907.dist-info → flwr_nightly-1.12.0.dev20240913.dist-info}/LICENSE +0 -0
  117. {flwr_nightly-1.12.0.dev20240907.dist-info → flwr_nightly-1.12.0.dev20240913.dist-info}/WHEEL +0 -0
  118. {flwr_nightly-1.12.0.dev20240907.dist-info → flwr_nightly-1.12.0.dev20240913.dist-info}/entry_points.txt +0 -0
flwr/cli/build.py CHANGED
@@ -17,12 +17,11 @@
17
17
  import os
18
18
  import zipfile
19
19
  from pathlib import Path
20
- from typing import Optional
20
+ from typing import Annotated, Optional
21
21
 
22
22
  import pathspec
23
23
  import tomli_w
24
24
  import typer
25
- from typing_extensions import Annotated
26
25
 
27
26
  from .config_utils import load_and_validate
28
27
  from .utils import get_sha256_hash, is_valid_project_name
flwr/cli/config_utils.py CHANGED
@@ -17,7 +17,7 @@
17
17
  import zipfile
18
18
  from io import BytesIO
19
19
  from pathlib import Path
20
- from typing import IO, Any, Dict, List, Optional, Tuple, Union, get_args
20
+ from typing import IO, Any, Optional, Union, get_args
21
21
 
22
22
  import tomli
23
23
 
@@ -25,7 +25,7 @@ from flwr.common import object_ref
25
25
  from flwr.common.typing import UserConfigValue
26
26
 
27
27
 
28
- def get_fab_config(fab_file: Union[Path, bytes]) -> Dict[str, Any]:
28
+ def get_fab_config(fab_file: Union[Path, bytes]) -> dict[str, Any]:
29
29
  """Extract the config from a FAB file or path.
30
30
 
31
31
  Parameters
@@ -62,7 +62,7 @@ def get_fab_config(fab_file: Union[Path, bytes]) -> Dict[str, Any]:
62
62
  return conf
63
63
 
64
64
 
65
- def get_fab_metadata(fab_file: Union[Path, bytes]) -> Tuple[str, str]:
65
+ def get_fab_metadata(fab_file: Union[Path, bytes]) -> tuple[str, str]:
66
66
  """Extract the fab_id and the fab_version from a FAB file or path.
67
67
 
68
68
  Parameters
@@ -87,7 +87,7 @@ def get_fab_metadata(fab_file: Union[Path, bytes]) -> Tuple[str, str]:
87
87
  def load_and_validate(
88
88
  path: Optional[Path] = None,
89
89
  check_module: bool = True,
90
- ) -> Tuple[Optional[Dict[str, Any]], List[str], List[str]]:
90
+ ) -> tuple[Optional[dict[str, Any]], list[str], list[str]]:
91
91
  """Load and validate pyproject.toml as dict.
92
92
 
93
93
  Returns
@@ -116,7 +116,7 @@ def load_and_validate(
116
116
  return (config, errors, warnings)
117
117
 
118
118
 
119
- def load(toml_path: Path) -> Optional[Dict[str, Any]]:
119
+ def load(toml_path: Path) -> Optional[dict[str, Any]]:
120
120
  """Load pyproject.toml and return as dict."""
121
121
  if not toml_path.is_file():
122
122
  return None
@@ -125,7 +125,7 @@ def load(toml_path: Path) -> Optional[Dict[str, Any]]:
125
125
  return load_from_string(toml_file.read())
126
126
 
127
127
 
128
- def _validate_run_config(config_dict: Dict[str, Any], errors: List[str]) -> None:
128
+ def _validate_run_config(config_dict: dict[str, Any], errors: list[str]) -> None:
129
129
  for key, value in config_dict.items():
130
130
  if isinstance(value, dict):
131
131
  _validate_run_config(config_dict[key], errors)
@@ -137,7 +137,7 @@ def _validate_run_config(config_dict: Dict[str, Any], errors: List[str]) -> None
137
137
 
138
138
 
139
139
  # pylint: disable=too-many-branches
140
- def validate_fields(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]]:
140
+ def validate_fields(config: dict[str, Any]) -> tuple[bool, list[str], list[str]]:
141
141
  """Validate pyproject.toml fields."""
142
142
  errors = []
143
143
  warnings = []
@@ -183,10 +183,10 @@ def validate_fields(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]]
183
183
 
184
184
 
185
185
  def validate(
186
- config: Dict[str, Any],
186
+ config: dict[str, Any],
187
187
  check_module: bool = True,
188
188
  project_dir: Optional[Union[str, Path]] = None,
189
- ) -> Tuple[bool, List[str], List[str]]:
189
+ ) -> tuple[bool, list[str], list[str]]:
190
190
  """Validate pyproject.toml."""
191
191
  is_valid, errors, warnings = validate_fields(config)
192
192
 
@@ -210,7 +210,7 @@ def validate(
210
210
  return True, [], []
211
211
 
212
212
 
213
- def load_from_string(toml_content: str) -> Optional[Dict[str, Any]]:
213
+ def load_from_string(toml_content: str) -> Optional[dict[str, Any]]:
214
214
  """Load TOML content from a string and return as dict."""
215
215
  try:
216
216
  data = tomli.loads(toml_content)
flwr/cli/install.py CHANGED
@@ -21,10 +21,9 @@ import tempfile
21
21
  import zipfile
22
22
  from io import BytesIO
23
23
  from pathlib import Path
24
- from typing import IO, Optional, Union
24
+ from typing import IO, Annotated, Optional, Union
25
25
 
26
26
  import typer
27
- from typing_extensions import Annotated
28
27
 
29
28
  from flwr.common.config import get_flwr_dir
30
29
 
flwr/cli/new/new.py CHANGED
@@ -18,10 +18,9 @@ import re
18
18
  from enum import Enum
19
19
  from pathlib import Path
20
20
  from string import Template
21
- from typing import Dict, Optional
21
+ from typing import Annotated, Optional
22
22
 
23
23
  import typer
24
- from typing_extensions import Annotated
25
24
 
26
25
  from ..utils import (
27
26
  is_valid_project_name,
@@ -70,7 +69,7 @@ def load_template(name: str) -> str:
70
69
  return tpl_file.read()
71
70
 
72
71
 
73
- def render_template(template: str, data: Dict[str, str]) -> str:
72
+ def render_template(template: str, data: dict[str, str]) -> str:
74
73
  """Render template."""
75
74
  tpl_file = load_template(template)
76
75
  tpl = Template(tpl_file)
@@ -85,7 +84,7 @@ def create_file(file_path: Path, content: str) -> None:
85
84
  file_path.write_text(content)
86
85
 
87
86
 
88
- def render_and_create(file_path: Path, template: str, context: Dict[str, str]) -> None:
87
+ def render_and_create(file_path: Path, template: str, context: dict[str, str]) -> None:
89
88
  """Render template and write to file."""
90
89
  content = render_template(template, context)
91
90
  create_file(file_path, content)
@@ -136,36 +135,23 @@ def new(
136
135
  username = prompt_text("Please provide your Flower username")
137
136
 
138
137
  if framework is not None:
139
- framework_str_upper = str(framework.value)
138
+ framework_str = str(framework.value)
140
139
  else:
141
- framework_value = prompt_options(
140
+ framework_str = prompt_options(
142
141
  "Please select ML framework by typing in the number",
143
142
  [mlf.value for mlf in MlFramework],
144
143
  )
145
- selected_value = [
146
- name
147
- for name, value in vars(MlFramework).items()
148
- if value == framework_value
149
- ]
150
- framework_str_upper = selected_value[0]
151
-
152
- framework_str = framework_str_upper.lower()
153
144
 
154
145
  llm_challenge_str = None
155
- if framework_str == "flowertune":
146
+ if framework_str == MlFramework.FLOWERTUNE:
156
147
  llm_challenge_value = prompt_options(
157
148
  "Please select LLM challenge by typing in the number",
158
149
  sorted([challenge.value for challenge in LlmChallengeName]),
159
150
  )
160
- selected_value = [
161
- name
162
- for name, value in vars(LlmChallengeName).items()
163
- if value == llm_challenge_value
164
- ]
165
- llm_challenge_str = selected_value[0]
166
- llm_challenge_str = llm_challenge_str.lower()
151
+ llm_challenge_str = llm_challenge_value.lower()
167
152
 
168
- is_baseline_project = framework_str == "baseline"
153
+ if framework_str == MlFramework.BASELINE:
154
+ framework_str = "baseline"
169
155
 
170
156
  print(
171
157
  typer.style(
@@ -176,19 +162,21 @@ def new(
176
162
  )
177
163
 
178
164
  context = {
179
- "framework_str": framework_str_upper,
165
+ "framework_str": framework_str,
180
166
  "import_name": import_name.replace("-", "_"),
181
167
  "package_name": package_name,
182
168
  "project_name": app_name,
183
169
  "username": username,
184
170
  }
185
171
 
172
+ template_name = framework_str.lower()
173
+
186
174
  # List of files to render
187
175
  if llm_challenge_str:
188
176
  files = {
189
177
  ".gitignore": {"template": "app/.gitignore.tpl"},
190
- "pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"},
191
- "README.md": {"template": f"app/README.{framework_str}.md.tpl"},
178
+ "pyproject.toml": {"template": f"app/pyproject.{template_name}.toml.tpl"},
179
+ "README.md": {"template": f"app/README.{template_name}.md.tpl"},
192
180
  f"{import_name}/__init__.py": {"template": "app/code/__init__.py.tpl"},
193
181
  f"{import_name}/server_app.py": {
194
182
  "template": "app/code/flwr_tune/server_app.py.tpl"
@@ -235,44 +223,42 @@ def new(
235
223
  files = {
236
224
  ".gitignore": {"template": "app/.gitignore.tpl"},
237
225
  "README.md": {"template": "app/README.md.tpl"},
238
- "pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"},
226
+ "pyproject.toml": {"template": f"app/pyproject.{template_name}.toml.tpl"},
239
227
  f"{import_name}/__init__.py": {"template": "app/code/__init__.py.tpl"},
240
228
  f"{import_name}/server_app.py": {
241
- "template": f"app/code/server.{framework_str}.py.tpl"
229
+ "template": f"app/code/server.{template_name}.py.tpl"
242
230
  },
243
231
  f"{import_name}/client_app.py": {
244
- "template": f"app/code/client.{framework_str}.py.tpl"
232
+ "template": f"app/code/client.{template_name}.py.tpl"
245
233
  },
246
234
  }
247
235
 
248
236
  # Depending on the framework, generate task.py file
249
237
  frameworks_with_tasks = [
250
- MlFramework.PYTORCH.value.lower(),
251
- MlFramework.JAX.value.lower(),
252
- MlFramework.HUGGINGFACE.value.lower(),
253
- MlFramework.MLX.value.lower(),
254
- MlFramework.TENSORFLOW.value.lower(),
238
+ MlFramework.PYTORCH.value,
239
+ MlFramework.JAX.value,
240
+ MlFramework.HUGGINGFACE.value,
241
+ MlFramework.MLX.value,
242
+ MlFramework.TENSORFLOW.value,
255
243
  ]
256
244
  if framework_str in frameworks_with_tasks:
257
245
  files[f"{import_name}/task.py"] = {
258
- "template": f"app/code/task.{framework_str}.py.tpl"
246
+ "template": f"app/code/task.{template_name}.py.tpl"
259
247
  }
260
248
 
261
- if is_baseline_project:
249
+ if framework_str == "baseline":
262
250
  # Include additional files for baseline template
263
251
  for file_name in ["model", "dataset", "strategy", "utils", "__init__"]:
264
252
  files[f"{import_name}/{file_name}.py"] = {
265
- "template": f"app/code/{file_name}.{framework_str}.py.tpl"
253
+ "template": f"app/code/{file_name}.{template_name}.py.tpl"
266
254
  }
267
255
 
268
256
  # Replace README.md
269
- files["README.md"]["template"] = f"app/README.{framework_str}.md.tpl"
257
+ files["README.md"]["template"] = f"app/README.{template_name}.md.tpl"
270
258
 
271
259
  # Add LICENSE
272
260
  files["LICENSE"] = {"template": "app/LICENSE.tpl"}
273
261
 
274
- context["framework_str"] = "baseline"
275
-
276
262
  for file_path, value in files.items():
277
263
  render_and_create(
278
264
  file_path=project_dir / file_path,
@@ -1,18 +1,11 @@
1
1
  """$project_name: A Flower / $framework_str app."""
2
2
 
3
+ import torch
3
4
  from flwr.client import ClientApp, NumPyClient
4
5
  from flwr.common import Context
5
6
  from transformers import AutoModelForSequenceClassification
6
7
 
7
- from $import_name.task import (
8
- get_weights,
9
- load_data,
10
- set_weights,
11
- train,
12
- test,
13
- CHECKPOINT,
14
- DEVICE,
15
- )
8
+ from $import_name.task import get_weights, load_data, set_weights, test, train
16
9
 
17
10
 
18
11
  # Flower client
@@ -22,37 +15,34 @@ class FlowerClient(NumPyClient):
22
15
  self.trainloader = trainloader
23
16
  self.testloader = testloader
24
17
  self.local_epochs = local_epochs
25
-
26
- def get_parameters(self, config):
27
- return get_weights(self.net)
28
-
29
- def set_parameters(self, parameters):
30
- set_weights(self.net, parameters)
18
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
19
+ self.net.to(self.device)
31
20
 
32
21
  def fit(self, parameters, config):
33
- self.set_parameters(parameters)
34
- train(
35
- self.net,
36
- self.trainloader,
37
- epochs=self.local_epochs,
38
- )
39
- return self.get_parameters(config={}), len(self.trainloader), {}
22
+ set_weights(self.net, parameters)
23
+ train(self.net, self.trainloader, epochs=self.local_epochs, device=self.device)
24
+ return get_weights(self.net), len(self.trainloader), {}
40
25
 
41
26
  def evaluate(self, parameters, config):
42
- self.set_parameters(parameters)
43
- loss, accuracy = test(self.net, self.testloader)
27
+ set_weights(self.net, parameters)
28
+ loss, accuracy = test(self.net, self.testloader, self.device)
44
29
  return float(loss), len(self.testloader), {"accuracy": accuracy}
45
30
 
46
31
 
47
32
  def client_fn(context: Context):
48
- # Load model and data
49
- net = AutoModelForSequenceClassification.from_pretrained(
50
- CHECKPOINT, num_labels=2
51
- ).to(DEVICE)
52
33
 
34
+ # Get this client's dataset partition
53
35
  partition_id = context.node_config["partition-id"]
54
36
  num_partitions = context.node_config["num-partitions"]
55
- trainloader, valloader = load_data(partition_id, num_partitions)
37
+ model_name = context.run_config["model-name"]
38
+ trainloader, valloader = load_data(partition_id, num_partitions, model_name)
39
+
40
+ # Load model
41
+ num_labels = context.run_config["num-labels"]
42
+ net = AutoModelForSequenceClassification.from_pretrained(
43
+ model_name, num_labels=num_labels
44
+ )
45
+
56
46
  local_epochs = context.run_config["local-epochs"]
57
47
 
58
48
  # Return Client instance
@@ -17,9 +17,6 @@ class FlowerClient(NumPyClient):
17
17
  self.batch_size = batch_size
18
18
  self.verbose = verbose
19
19
 
20
- def get_parameters(self, config):
21
- return self.model.get_weights()
22
-
23
20
  def fit(self, parameters, config):
24
21
  self.model.set_weights(parameters)
25
22
  self.model.fit(
@@ -1,18 +1,33 @@
1
1
  """$project_name: A Flower / $framework_str app."""
2
2
 
3
- from flwr.common import Context
4
- from flwr.server.strategy import FedAvg
3
+ from flwr.common import Context, ndarrays_to_parameters
5
4
  from flwr.server import ServerApp, ServerAppComponents, ServerConfig
5
+ from flwr.server.strategy import FedAvg
6
+ from transformers import AutoModelForSequenceClassification
7
+
8
+ from $import_name.task import get_weights
6
9
 
7
10
 
8
11
  def server_fn(context: Context):
9
12
  # Read from config
10
13
  num_rounds = context.run_config["num-server-rounds"]
14
+ fraction_fit = context.run_config["fraction-fit"]
15
+
16
+ # Initialize global model
17
+ model_name = context.run_config["model-name"]
18
+ num_labels = context.run_config["num-labels"]
19
+ net = AutoModelForSequenceClassification.from_pretrained(
20
+ model_name, num_labels=num_labels
21
+ )
22
+
23
+ weights = get_weights(net)
24
+ initial_parameters = ndarrays_to_parameters(weights)
11
25
 
12
26
  # Define strategy
13
27
  strategy = FedAvg(
14
- fraction_fit=1.0,
28
+ fraction_fit=fraction_fit,
15
29
  fraction_evaluate=1.0,
30
+ initial_parameters=initial_parameters,
16
31
  )
17
32
  config = ServerConfig(num_rounds=num_rounds)
18
33
 
@@ -4,24 +4,25 @@ import warnings
4
4
  from collections import OrderedDict
5
5
 
6
6
  import torch
7
+ import transformers
8
+ from datasets.utils.logging import disable_progress_bar
7
9
  from evaluate import load as load_metric
10
+ from flwr_datasets import FederatedDataset
11
+ from flwr_datasets.partitioner import IidPartitioner
8
12
  from torch.optim import AdamW
9
13
  from torch.utils.data import DataLoader
10
14
  from transformers import AutoTokenizer, DataCollatorWithPadding
11
15
 
12
- from flwr_datasets import FederatedDataset
13
- from flwr_datasets.partitioner import IidPartitioner
14
-
15
-
16
16
  warnings.filterwarnings("ignore", category=UserWarning)
17
- DEVICE = torch.device("cpu")
18
- CHECKPOINT = "distilbert-base-uncased" # transformer model checkpoint
17
+ warnings.filterwarnings("ignore", category=FutureWarning)
18
+ disable_progress_bar()
19
+ transformers.logging.set_verbosity_error()
19
20
 
20
21
 
21
22
  fds = None # Cache FederatedDataset
22
23
 
23
24
 
24
- def load_data(partition_id: int, num_partitions: int):
25
+ def load_data(partition_id: int, num_partitions: int, model_name: str):
25
26
  """Load IMDB data (training and eval)"""
26
27
  # Only initialize `FederatedDataset` once
27
28
  global fds
@@ -35,10 +36,12 @@ def load_data(partition_id: int, num_partitions: int):
35
36
  # Divide data: 80% train, 20% test
36
37
  partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
37
38
 
38
- tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
39
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
39
40
 
40
41
  def tokenize_function(examples):
41
- return tokenizer(examples["text"], truncation=True)
42
+ return tokenizer(
43
+ examples["text"], truncation=True, add_special_tokens=True, max_length=512
44
+ )
42
45
 
43
46
  partition_train_test = partition_train_test.map(tokenize_function, batched=True)
44
47
  partition_train_test = partition_train_test.remove_columns("text")
@@ -59,12 +62,12 @@ def load_data(partition_id: int, num_partitions: int):
59
62
  return trainloader, testloader
60
63
 
61
64
 
62
- def train(net, trainloader, epochs):
65
+ def train(net, trainloader, epochs, device):
63
66
  optimizer = AdamW(net.parameters(), lr=5e-5)
64
67
  net.train()
65
68
  for _ in range(epochs):
66
69
  for batch in trainloader:
67
- batch = {k: v.to(DEVICE) for k, v in batch.items()}
70
+ batch = {k: v.to(device) for k, v in batch.items()}
68
71
  outputs = net(**batch)
69
72
  loss = outputs.loss
70
73
  loss.backward()
@@ -72,12 +75,12 @@ def train(net, trainloader, epochs):
72
75
  optimizer.zero_grad()
73
76
 
74
77
 
75
- def test(net, testloader):
78
+ def test(net, testloader, device):
76
79
  metric = load_metric("accuracy")
77
80
  loss = 0
78
81
  net.eval()
79
82
  for batch in testloader:
80
- batch = {k: v.to(DEVICE) for k, v in batch.items()}
83
+ batch = {k: v.to(device) for k, v in batch.items()}
81
84
  with torch.no_grad():
82
85
  outputs = net(**batch)
83
86
  logits = outputs.logits
@@ -8,7 +8,7 @@ version = "1.0.0"
8
8
  description = ""
9
9
  license = "Apache-2.0"
10
10
  dependencies = [
11
- "flwr[simulation]>=1.10.0",
11
+ "flwr[simulation]>=1.11.0",
12
12
  "flwr-datasets>=0.3.0",
13
13
  "torch==2.2.1",
14
14
  "transformers>=4.30.0,<5.0",
@@ -29,10 +29,18 @@ clientapp = "$import_name.client_app:app"
29
29
 
30
30
  [tool.flwr.app.config]
31
31
  num-server-rounds = 3
32
+ fraction-fit = 0.5
32
33
  local-epochs = 1
34
+ model-name = "prajjwal1/bert-tiny" # Set a larger model if you have access to more GPU resources
35
+ num-labels = 2
33
36
 
34
37
  [tool.flwr.federations]
35
38
  default = "localhost"
36
39
 
37
40
  [tool.flwr.federations.localhost]
38
41
  options.num-supernodes = 10
42
+
43
+ [tool.flwr.federations.localhost-gpu]
44
+ options.num-supernodes = 10
45
+ options.backend.client-resources.num-cpus = 4 # each ClientApp is assumed to use 4 CPUs
46
+ options.backend.client-resources.num-gpus = 0.25 # at most 4 ClientApps will run in a given GPU
flwr/cli/run/run.py CHANGED
@@ -20,10 +20,9 @@ import subprocess
20
20
  import sys
21
21
  from logging import DEBUG
22
22
  from pathlib import Path
23
- from typing import Any, Dict, List, Optional
23
+ from typing import Annotated, Any, Optional
24
24
 
25
25
  import typer
26
- from typing_extensions import Annotated
27
26
 
28
27
  from flwr.cli.build import build
29
28
  from flwr.cli.config_utils import load_and_validate
@@ -52,7 +51,7 @@ def run(
52
51
  typer.Argument(help="Name of the federation to run the app on."),
53
52
  ] = None,
54
53
  config_overrides: Annotated[
55
- Optional[List[str]],
54
+ Optional[list[str]],
56
55
  typer.Option(
57
56
  "--run-config",
58
57
  "-c",
@@ -125,8 +124,8 @@ def run(
125
124
 
126
125
  def _run_with_superexec(
127
126
  app: Path,
128
- federation_config: Dict[str, Any],
129
- config_overrides: Optional[List[str]],
127
+ federation_config: dict[str, Any],
128
+ config_overrides: Optional[list[str]],
130
129
  ) -> None:
131
130
 
132
131
  insecure_str = federation_config.get("insecure")
@@ -187,8 +186,8 @@ def _run_with_superexec(
187
186
 
188
187
  def _run_without_superexec(
189
188
  app: Optional[Path],
190
- federation_config: Dict[str, Any],
191
- config_overrides: Optional[List[str]],
189
+ federation_config: dict[str, Any],
190
+ config_overrides: Optional[list[str]],
192
191
  federation: str,
193
192
  ) -> None:
194
193
  try:
flwr/cli/utils.py CHANGED
@@ -17,7 +17,7 @@
17
17
  import hashlib
18
18
  import re
19
19
  from pathlib import Path
20
- from typing import Callable, List, Optional, cast
20
+ from typing import Callable, Optional, cast
21
21
 
22
22
  import typer
23
23
 
@@ -40,7 +40,7 @@ def prompt_text(
40
40
  return cast(str, result)
41
41
 
42
42
 
43
- def prompt_options(text: str, options: List[str]) -> str:
43
+ def prompt_options(text: str, options: list[str]) -> str:
44
44
  """Ask user to select one of the given options and return the selected item."""
45
45
  # Turn options into a list with index as in " [ 0] quickstart-pytorch"
46
46
  options_formatted = [
flwr/client/app.py CHANGED
@@ -18,10 +18,11 @@ import signal
18
18
  import subprocess
19
19
  import sys
20
20
  import time
21
+ from contextlib import AbstractContextManager
21
22
  from dataclasses import dataclass
22
23
  from logging import ERROR, INFO, WARN
23
24
  from pathlib import Path
24
- from typing import Callable, ContextManager, Dict, Optional, Tuple, Type, Union, cast
25
+ from typing import Callable, Optional, Union, cast
25
26
 
26
27
  import grpc
27
28
  from cryptography.hazmat.primitives.asymmetric import ec
@@ -35,6 +36,7 @@ from flwr.client.typing import ClientFnExt
35
36
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH, Context, EventType, Message, event
36
37
  from flwr.common.address import parse_address
37
38
  from flwr.common.constant import (
39
+ CLIENTAPPIO_API_DEFAULT_ADDRESS,
38
40
  MISSING_EXTRA_REST,
39
41
  RUN_ID_NUM_BYTES,
40
42
  TRANSPORT_TYPE_GRPC_ADAPTER,
@@ -60,8 +62,6 @@ from .message_handler.message_handler import handle_control_message
60
62
  from .node_state import NodeState
61
63
  from .numpy_client import NumPyClient
62
64
 
63
- ADDRESS_CLIENTAPPIO_API_GRPC_RERE = "0.0.0.0:9094"
64
-
65
65
  ISOLATION_MODE_SUBPROCESS = "subprocess"
66
66
  ISOLATION_MODE_PROCESS = "process"
67
67
 
@@ -95,7 +95,7 @@ def start_client(
95
95
  insecure: Optional[bool] = None,
96
96
  transport: Optional[str] = None,
97
97
  authentication_keys: Optional[
98
- Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
98
+ tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
99
99
  ] = None,
100
100
  max_retries: Optional[int] = None,
101
101
  max_wait_time: Optional[float] = None,
@@ -205,13 +205,13 @@ def start_client_internal(
205
205
  insecure: Optional[bool] = None,
206
206
  transport: Optional[str] = None,
207
207
  authentication_keys: Optional[
208
- Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
208
+ tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
209
209
  ] = None,
210
210
  max_retries: Optional[int] = None,
211
211
  max_wait_time: Optional[float] = None,
212
212
  flwr_path: Optional[Path] = None,
213
213
  isolation: Optional[str] = None,
214
- supernode_address: Optional[str] = ADDRESS_CLIENTAPPIO_API_GRPC_RERE,
214
+ supernode_address: Optional[str] = CLIENTAPPIO_API_DEFAULT_ADDRESS,
215
215
  ) -> None:
216
216
  """Start a Flower client node which connects to a Flower server.
217
217
 
@@ -266,7 +266,7 @@ def start_client_internal(
266
266
  by the SuperNode and communicates using gRPC at the address
267
267
  `supernode_address`. If `process`, the `ClientApp` runs in a separate isolated
268
268
  process and communicates using gRPC at the address `supernode_address`.
269
- supernode_address : Optional[str] (default: `ADDRESS_CLIENTAPPIO_API_GRPC_RERE`)
269
+ supernode_address : Optional[str] (default: `CLIENTAPPIO_API_DEFAULT_ADDRESS`)
270
270
  The SuperNode gRPC server address.
271
271
  """
272
272
  if insecure is None:
@@ -357,7 +357,7 @@ def start_client_internal(
357
357
  # NodeState gets initialized when the first connection is established
358
358
  node_state: Optional[NodeState] = None
359
359
 
360
- runs: Dict[int, Run] = {}
360
+ runs: dict[int, Run] = {}
361
361
 
362
362
  while not app_state_tracker.interrupt:
363
363
  sleep_duration: int = 0
@@ -690,7 +690,7 @@ def start_numpy_client(
690
690
  )
691
691
 
692
692
 
693
- def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
693
+ def _init_connection(transport: Optional[str], server_address: str) -> tuple[
694
694
  Callable[
695
695
  [
696
696
  str,
@@ -698,10 +698,10 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
698
698
  RetryInvoker,
699
699
  int,
700
700
  Union[bytes, str, None],
701
- Optional[Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]],
701
+ Optional[tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]],
702
702
  ],
703
- ContextManager[
704
- Tuple[
703
+ AbstractContextManager[
704
+ tuple[
705
705
  Callable[[], Optional[Message]],
706
706
  Callable[[Message], None],
707
707
  Optional[Callable[[], Optional[int]]],
@@ -712,7 +712,7 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
712
712
  ],
713
713
  ],
714
714
  str,
715
- Type[Exception],
715
+ type[Exception],
716
716
  ]:
717
717
  # Parse IP address
718
718
  parsed_address = parse_address(server_address)
@@ -770,7 +770,7 @@ class _AppStateTracker:
770
770
  signal.signal(signal.SIGTERM, signal_handler)
771
771
 
772
772
 
773
- def run_clientappio_api_grpc(address: str) -> Tuple[grpc.Server, ClientAppIoServicer]:
773
+ def run_clientappio_api_grpc(address: str) -> tuple[grpc.Server, ClientAppIoServicer]:
774
774
  """Run ClientAppIo API gRPC server."""
775
775
  clientappio_servicer: grpc.Server = ClientAppIoServicer()
776
776
  clientappio_add_servicer_to_server_fn = add_ClientAppIoServicer_to_server