flwr-nightly 1.10.0.dev20240707__py3-none-any.whl → 1.10.0.dev20240722__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

Files changed (92) hide show
  1. flwr/cli/build.py +16 -2
  2. flwr/cli/config_utils.py +36 -14
  3. flwr/cli/install.py +17 -1
  4. flwr/cli/new/new.py +31 -20
  5. flwr/cli/new/templates/app/code/client.hf.py.tpl +11 -3
  6. flwr/cli/new/templates/app/code/client.jax.py.tpl +2 -1
  7. flwr/cli/new/templates/app/code/client.mlx.py.tpl +15 -10
  8. flwr/cli/new/templates/app/code/client.numpy.py.tpl +2 -1
  9. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +12 -3
  10. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +6 -3
  11. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +13 -3
  12. flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +2 -2
  13. flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +1 -1
  14. flwr/cli/new/templates/app/code/server.hf.py.tpl +16 -11
  15. flwr/cli/new/templates/app/code/server.jax.py.tpl +15 -8
  16. flwr/cli/new/templates/app/code/server.mlx.py.tpl +11 -7
  17. flwr/cli/new/templates/app/code/server.numpy.py.tpl +15 -8
  18. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +15 -13
  19. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +16 -10
  20. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +16 -13
  21. flwr/cli/new/templates/app/code/task.hf.py.tpl +2 -2
  22. flwr/cli/new/templates/app/code/task.mlx.py.tpl +2 -2
  23. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -1
  24. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +9 -12
  25. flwr/cli/new/templates/app/pyproject.hf.toml.tpl +17 -16
  26. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +17 -11
  27. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +17 -12
  28. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +12 -12
  29. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +13 -12
  30. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +12 -12
  31. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +15 -12
  32. flwr/cli/run/run.py +128 -53
  33. flwr/client/app.py +56 -24
  34. flwr/client/client_app.py +28 -8
  35. flwr/client/grpc_adapter_client/connection.py +3 -2
  36. flwr/client/grpc_client/connection.py +3 -2
  37. flwr/client/grpc_rere_client/connection.py +17 -6
  38. flwr/client/message_handler/message_handler.py +1 -1
  39. flwr/client/node_state.py +59 -12
  40. flwr/client/node_state_tests.py +4 -3
  41. flwr/client/rest_client/connection.py +19 -8
  42. flwr/client/supernode/app.py +55 -24
  43. flwr/client/typing.py +2 -2
  44. flwr/common/config.py +87 -2
  45. flwr/common/constant.py +3 -0
  46. flwr/common/context.py +24 -9
  47. flwr/common/logger.py +25 -0
  48. flwr/common/serde.py +45 -0
  49. flwr/common/telemetry.py +17 -0
  50. flwr/common/typing.py +5 -0
  51. flwr/proto/common_pb2.py +36 -0
  52. flwr/proto/common_pb2.pyi +121 -0
  53. flwr/proto/common_pb2_grpc.py +4 -0
  54. flwr/proto/common_pb2_grpc.pyi +4 -0
  55. flwr/proto/driver_pb2.py +24 -19
  56. flwr/proto/driver_pb2.pyi +21 -1
  57. flwr/proto/exec_pb2.py +16 -11
  58. flwr/proto/exec_pb2.pyi +22 -1
  59. flwr/proto/run_pb2.py +12 -7
  60. flwr/proto/run_pb2.pyi +22 -1
  61. flwr/proto/task_pb2.py +7 -8
  62. flwr/server/__init__.py +2 -0
  63. flwr/server/compat/legacy_context.py +5 -4
  64. flwr/server/driver/grpc_driver.py +82 -140
  65. flwr/server/run_serverapp.py +40 -15
  66. flwr/server/server_app.py +56 -10
  67. flwr/server/serverapp_components.py +52 -0
  68. flwr/server/superlink/driver/driver_servicer.py +18 -3
  69. flwr/server/superlink/fleet/message_handler/message_handler.py +13 -2
  70. flwr/server/superlink/fleet/vce/backend/backend.py +4 -4
  71. flwr/server/superlink/fleet/vce/backend/raybackend.py +10 -10
  72. flwr/server/superlink/fleet/vce/vce_api.py +149 -117
  73. flwr/server/superlink/state/in_memory_state.py +11 -3
  74. flwr/server/superlink/state/sqlite_state.py +23 -8
  75. flwr/server/superlink/state/state.py +7 -2
  76. flwr/server/typing.py +2 -0
  77. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -2
  78. flwr/simulation/app.py +4 -3
  79. flwr/simulation/ray_transport/ray_actor.py +15 -19
  80. flwr/simulation/ray_transport/ray_client_proxy.py +22 -9
  81. flwr/simulation/run_simulation.py +237 -66
  82. flwr/superexec/app.py +14 -7
  83. flwr/superexec/deployment.py +110 -33
  84. flwr/superexec/exec_grpc.py +5 -1
  85. flwr/superexec/exec_servicer.py +4 -1
  86. flwr/superexec/executor.py +18 -0
  87. flwr/superexec/simulation.py +151 -0
  88. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/METADATA +3 -2
  89. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/RECORD +92 -86
  90. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/LICENSE +0 -0
  91. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/WHEEL +0 -0
  92. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/entry_points.txt +0 -0
flwr/cli/build.py CHANGED
@@ -20,6 +20,7 @@ from pathlib import Path
20
20
  from typing import Optional
21
21
 
22
22
  import pathspec
23
+ import tomli_w
23
24
  import typer
24
25
  from typing_extensions import Annotated
25
26
 
@@ -85,7 +86,7 @@ def build(
85
86
 
86
87
  # Set the name of the zip file
87
88
  fab_filename = (
88
- f"{conf['flower']['publisher']}"
89
+ f"{conf['tool']['flwr']['app']['publisher']}"
89
90
  f".{directory.name}"
90
91
  f".{conf['project']['version'].replace('.', '-')}.fab"
91
92
  )
@@ -93,15 +94,28 @@ def build(
93
94
 
94
95
  allowed_extensions = {".py", ".toml", ".md"}
95
96
 
97
+ # Remove the 'federations' field from 'tool.flwr' if it exists
98
+ if (
99
+ "tool" in conf
100
+ and "flwr" in conf["tool"]
101
+ and "federations" in conf["tool"]["flwr"]
102
+ ):
103
+ del conf["tool"]["flwr"]["federations"]
104
+
105
+ toml_contents = tomli_w.dumps(conf)
106
+
96
107
  with zipfile.ZipFile(fab_filename, "w", zipfile.ZIP_DEFLATED) as fab_file:
108
+ fab_file.writestr("pyproject.toml", toml_contents)
109
+
110
+ # Continue with adding other files
97
111
  for root, _, files in os.walk(directory, topdown=True):
98
- # Filter directories and files based on .gitignore
99
112
  files = [
100
113
  f
101
114
  for f in files
102
115
  if not ignore_spec.match_file(Path(root) / f)
103
116
  and f != fab_filename
104
117
  and Path(f).suffix in allowed_extensions
118
+ and f != "pyproject.toml" # Exclude the original pyproject.toml
105
119
  ]
106
120
 
107
121
  for file in files:
flwr/cli/config_utils.py CHANGED
@@ -17,11 +17,12 @@
17
17
  import zipfile
18
18
  from io import BytesIO
19
19
  from pathlib import Path
20
- from typing import IO, Any, Dict, List, Optional, Tuple, Union
20
+ from typing import IO, Any, Dict, List, Optional, Tuple, Union, get_args
21
21
 
22
22
  import tomli
23
23
 
24
24
  from flwr.common import object_ref
25
+ from flwr.common.typing import UserConfigValue
25
26
 
26
27
 
27
28
  def get_fab_metadata(fab_file: Union[Path, bytes]) -> Tuple[str, str]:
@@ -60,7 +61,7 @@ def get_fab_metadata(fab_file: Union[Path, bytes]) -> Tuple[str, str]:
60
61
 
61
62
  return (
62
63
  conf["project"]["version"],
63
- f"{conf['flower']['publisher']}/{conf['project']['name']}",
64
+ f"{conf['tool']['flwr']['app']['publisher']}/{conf['project']['name']}",
64
65
  )
65
66
 
66
67
 
@@ -108,6 +109,17 @@ def load(path: Optional[Path] = None) -> Optional[Dict[str, Any]]:
108
109
  return load_from_string(toml_file.read())
109
110
 
110
111
 
112
+ def _validate_run_config(config_dict: Dict[str, Any], errors: List[str]) -> None:
113
+ for key, value in config_dict.items():
114
+ if isinstance(value, dict):
115
+ _validate_run_config(config_dict[key], errors)
116
+ elif not isinstance(value, get_args(UserConfigValue)):
117
+ raise ValueError(
118
+ f"The value for key {key} needs to be of type `int`, `float`, "
119
+ "`bool, `str`, or a `dict` of those.",
120
+ )
121
+
122
+
111
123
  # pylint: disable=too-many-branches
112
124
  def validate_fields(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]]:
113
125
  """Validate pyproject.toml fields."""
@@ -128,18 +140,28 @@ def validate_fields(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]]
128
140
  if "authors" not in config["project"]:
129
141
  warnings.append('Recommended property "authors" missing in [project]')
130
142
 
131
- if "flower" not in config:
132
- errors.append("Missing [flower] section")
143
+ if (
144
+ "tool" not in config
145
+ or "flwr" not in config["tool"]
146
+ or "app" not in config["tool"]["flwr"]
147
+ ):
148
+ errors.append("Missing [tool.flwr.app] section")
133
149
  else:
134
- if "publisher" not in config["flower"]:
135
- errors.append('Property "publisher" missing in [flower]')
136
- if "components" not in config["flower"]:
137
- errors.append("Missing [flower.components] section")
150
+ if "publisher" not in config["tool"]["flwr"]["app"]:
151
+ errors.append('Property "publisher" missing in [tool.flwr.app]')
152
+ if "config" in config["tool"]["flwr"]["app"]:
153
+ _validate_run_config(config["tool"]["flwr"]["app"]["config"], errors)
154
+ if "components" not in config["tool"]["flwr"]["app"]:
155
+ errors.append("Missing [tool.flwr.app.components] section")
138
156
  else:
139
- if "serverapp" not in config["flower"]["components"]:
140
- errors.append('Property "serverapp" missing in [flower.components]')
141
- if "clientapp" not in config["flower"]["components"]:
142
- errors.append('Property "clientapp" missing in [flower.components]')
157
+ if "serverapp" not in config["tool"]["flwr"]["app"]["components"]:
158
+ errors.append(
159
+ 'Property "serverapp" missing in [tool.flwr.app.components]'
160
+ )
161
+ if "clientapp" not in config["tool"]["flwr"]["app"]["components"]:
162
+ errors.append(
163
+ 'Property "clientapp" missing in [tool.flwr.app.components]'
164
+ )
143
165
 
144
166
  return len(errors) == 0, errors, warnings
145
167
 
@@ -155,14 +177,14 @@ def validate(
155
177
 
156
178
  # Validate serverapp
157
179
  is_valid, reason = object_ref.validate(
158
- config["flower"]["components"]["serverapp"], check_module
180
+ config["tool"]["flwr"]["app"]["components"]["serverapp"], check_module
159
181
  )
160
182
  if not is_valid and isinstance(reason, str):
161
183
  return False, [reason], []
162
184
 
163
185
  # Validate clientapp
164
186
  is_valid, reason = object_ref.validate(
165
- config["flower"]["components"]["clientapp"], check_module
187
+ config["tool"]["flwr"]["app"]["components"]["clientapp"], check_module
166
188
  )
167
189
 
168
190
  if not is_valid and isinstance(reason, str):
flwr/cli/install.py CHANGED
@@ -16,6 +16,7 @@
16
16
 
17
17
 
18
18
  import shutil
19
+ import subprocess
19
20
  import tempfile
20
21
  import zipfile
21
22
  from io import BytesIO
@@ -149,7 +150,7 @@ def validate_and_install(
149
150
  )
150
151
  raise typer.Exit(code=1)
151
152
 
152
- publisher = config["flower"]["publisher"]
153
+ publisher = config["tool"]["flwr"]["app"]["publisher"]
153
154
  project_name = config["project"]["name"]
154
155
  version = config["project"]["version"]
155
156
 
@@ -192,6 +193,21 @@ def validate_and_install(
192
193
  else:
193
194
  shutil.copy2(item, install_dir / item.name)
194
195
 
196
+ try:
197
+ subprocess.run(
198
+ ["pip", "install", "-e", install_dir, "--no-deps"],
199
+ capture_output=True,
200
+ text=True,
201
+ check=True,
202
+ )
203
+ except subprocess.CalledProcessError as e:
204
+ typer.secho(
205
+ f"❌ Failed to `pip install` package(s) from {install_dir}:\n{e.stderr}",
206
+ fg=typer.colors.RED,
207
+ bold=True,
208
+ )
209
+ raise typer.Exit(code=1) from e
210
+
195
211
  typer.secho(
196
212
  f"🎊 Successfully installed {project_name} to {install_dir}.",
197
213
  fg=typer.colors.GREEN,
flwr/cli/new/new.py CHANGED
@@ -14,9 +14,9 @@
14
14
  # ==============================================================================
15
15
  """Flower command line interface `new` command."""
16
16
 
17
- import os
18
17
  import re
19
18
  from enum import Enum
19
+ from pathlib import Path
20
20
  from string import Template
21
21
  from typing import Dict, Optional
22
22
 
@@ -59,10 +59,10 @@ class TemplateNotFound(Exception):
59
59
 
60
60
  def load_template(name: str) -> str:
61
61
  """Load template from template directory and return as text."""
62
- tpl_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "templates"))
63
- tpl_file_path = os.path.join(tpl_dir, name)
62
+ tpl_dir = (Path(__file__).parent / "templates").absolute()
63
+ tpl_file_path = tpl_dir / name
64
64
 
65
- if not os.path.isfile(tpl_file_path):
65
+ if not tpl_file_path.is_file():
66
66
  raise TemplateNotFound(f"Template '{name}' not found")
67
67
 
68
68
  with open(tpl_file_path, encoding="utf-8") as tpl_file:
@@ -78,14 +78,13 @@ def render_template(template: str, data: Dict[str, str]) -> str:
78
78
  return tpl.template
79
79
 
80
80
 
81
- def create_file(file_path: str, content: str) -> None:
81
+ def create_file(file_path: Path, content: str) -> None:
82
82
  """Create file including all nessecary directories and write content into file."""
83
- os.makedirs(os.path.dirname(file_path), exist_ok=True)
84
- with open(file_path, "w", encoding="utf-8") as f:
85
- f.write(content)
83
+ file_path.parent.mkdir(exist_ok=True)
84
+ file_path.write_text(content)
86
85
 
87
86
 
88
- def render_and_create(file_path: str, template: str, context: Dict[str, str]) -> None:
87
+ def render_and_create(file_path: Path, template: str, context: Dict[str, str]) -> None:
89
88
  """Render template and write to file."""
90
89
  content = render_template(template, context)
91
90
  create_file(file_path, content)
@@ -117,6 +116,21 @@ def new(
117
116
  default=sanitize_project_name(project_name),
118
117
  )
119
118
 
119
+ # Set project directory path
120
+ package_name = re.sub(r"[-_.]+", "-", project_name).lower()
121
+ import_name = package_name.replace("-", "_")
122
+ project_dir = Path.cwd() / package_name
123
+
124
+ if project_dir.exists():
125
+ if not typer.confirm(
126
+ typer.style(
127
+ f"\n💬 {project_name} already exists, do you want to override it?",
128
+ fg=typer.colors.MAGENTA,
129
+ bold=True,
130
+ )
131
+ ):
132
+ return
133
+
120
134
  if username is None:
121
135
  username = prompt_text("Please provide your Flower username")
122
136
 
@@ -136,6 +150,7 @@ def new(
136
150
 
137
151
  framework_str = framework_str.lower()
138
152
 
153
+ llm_challenge_str = None
139
154
  if framework_str == "flowertune":
140
155
  llm_challenge_value = prompt_options(
141
156
  "Please select LLM challenge by typing in the number",
@@ -157,12 +172,6 @@ def new(
157
172
  )
158
173
  )
159
174
 
160
- # Set project directory path
161
- cwd = os.getcwd()
162
- package_name = re.sub(r"[-_.]+", "-", project_name).lower()
163
- import_name = package_name.replace("-", "_")
164
- project_dir = os.path.join(cwd, package_name)
165
-
166
175
  context = {
167
176
  "project_name": project_name,
168
177
  "package_name": package_name,
@@ -171,7 +180,7 @@ def new(
171
180
  }
172
181
 
173
182
  # List of files to render
174
- if framework_str == "flowertune":
183
+ if llm_challenge_str:
175
184
  files = {
176
185
  ".gitignore": {"template": "app/.gitignore.tpl"},
177
186
  "pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"},
@@ -228,10 +237,10 @@ def new(
228
237
  "README.md": {"template": "app/README.md.tpl"},
229
238
  "pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"},
230
239
  f"{import_name}/__init__.py": {"template": "app/code/__init__.py.tpl"},
231
- f"{import_name}/server.py": {
240
+ f"{import_name}/server_app.py": {
232
241
  "template": f"app/code/server.{framework_str}.py.tpl"
233
242
  },
234
- f"{import_name}/client.py": {
243
+ f"{import_name}/client_app.py": {
235
244
  "template": f"app/code/client.{framework_str}.py.tpl"
236
245
  },
237
246
  }
@@ -251,7 +260,7 @@ def new(
251
260
 
252
261
  for file_path, value in files.items():
253
262
  render_and_create(
254
- file_path=os.path.join(project_dir, file_path),
263
+ file_path=project_dir / file_path,
255
264
  template=value["template"],
256
265
  context=context,
257
266
  )
@@ -264,9 +273,11 @@ def new(
264
273
  bold=True,
265
274
  )
266
275
  )
276
+
277
+ _add = " huggingface-cli login\n" if framework_str == "flowertune" else ""
267
278
  print(
268
279
  typer.style(
269
- f" cd {package_name}\n" + " pip install -e .\n flwr run\n",
280
+ f" cd {package_name}\n" + " pip install -e .\n" + _add + " flwr run\n",
270
281
  fg=typer.colors.BRIGHT_CYAN,
271
282
  bold=True,
272
283
  )
@@ -1,6 +1,7 @@
1
1
  """$project_name: A Flower / HuggingFace Transformers app."""
2
2
 
3
3
  from flwr.client import ClientApp, NumPyClient
4
+ from flwr.common import Context
4
5
  from transformers import AutoModelForSequenceClassification
5
6
 
6
7
  from $import_name.task import (
@@ -29,7 +30,11 @@ class FlowerClient(NumPyClient):
29
30
 
30
31
  def fit(self, parameters, config):
31
32
  self.set_parameters(parameters)
32
- train(self.net, self.trainloader, epochs=1)
33
+ train(
34
+ self.net,
35
+ self.trainloader,
36
+ epochs=int(self.context.run_config["local-epochs"]),
37
+ )
33
38
  return self.get_parameters(config={}), len(self.trainloader), {}
34
39
 
35
40
  def evaluate(self, parameters, config):
@@ -38,12 +43,15 @@ class FlowerClient(NumPyClient):
38
43
  return float(loss), len(self.testloader), {"accuracy": accuracy}
39
44
 
40
45
 
41
- def client_fn(cid):
46
+ def client_fn(context: Context):
42
47
  # Load model and data
43
48
  net = AutoModelForSequenceClassification.from_pretrained(
44
49
  CHECKPOINT, num_labels=2
45
50
  ).to(DEVICE)
46
- trainloader, valloader = load_data(int(cid), 2)
51
+
52
+ partition_id = int(context.node_config["partition-id"])
53
+ num_partitions = int(context.node_config["num-partitions"])
54
+ trainloader, valloader = load_data(partition_id, num_partitions)
47
55
 
48
56
  # Return Client instance
49
57
  return FlowerClient(net, trainloader, valloader).to_client()
@@ -2,6 +2,7 @@
2
2
 
3
3
  import jax
4
4
  from flwr.client import NumPyClient, ClientApp
5
+ from flwr.common import Context
5
6
 
6
7
  from $import_name.task import (
7
8
  evaluation,
@@ -44,7 +45,7 @@ class FlowerClient(NumPyClient):
44
45
  )
45
46
  return float(loss), num_examples, {"loss": float(loss)}
46
47
 
47
- def client_fn(cid):
48
+ def client_fn(context: Context):
48
49
  # Return Client instance
49
50
  return FlowerClient().to_client()
50
51
 
@@ -4,6 +4,7 @@ import mlx.core as mx
4
4
  import mlx.nn as nn
5
5
  import mlx.optimizers as optim
6
6
  from flwr.client import NumPyClient, ClientApp
7
+ from flwr.common import Context
7
8
 
8
9
  from $import_name.task import (
9
10
  batch_iterate,
@@ -19,17 +20,19 @@ from $import_name.task import (
19
20
  # Define Flower Client and client_fn
20
21
  class FlowerClient(NumPyClient):
21
22
  def __init__(self, data):
22
- num_layers = 2
23
- hidden_dim = 32
23
+ num_layers = int(self.context.run_config["num-layers"])
24
+ hidden_dim = int(self.context.run_config["hidden-dim"])
24
25
  num_classes = 10
25
- batch_size = 256
26
- num_epochs = 1
27
- learning_rate = 1e-1
26
+ batch_size = int(self.context.run_config["batch-size"])
27
+ learning_rate = float(self.context.run_config["lr"])
28
+ num_epochs = int(self.context.run_config["local-epochs"])
28
29
 
29
30
  self.train_images, self.train_labels, self.test_images, self.test_labels = data
30
- self.model = MLP(num_layers, self.train_images.shape[-1], hidden_dim, num_classes)
31
- self.optimizer = optim.SGD(learning_rate=learning_rate)
32
- self.loss_and_grad_fn = nn.value_and_grad(self.model, loss_fn)
31
+ self.model = MLP(
32
+ num_layers, self.train_images.shape[-1], hidden_dim, num_classes
33
+ )
34
+ self.optimizer = optim.SGD(learning_rate=learning_rate)
35
+ self.loss_and_grad_fn = nn.value_and_grad(self.model, loss_fn)
33
36
  self.num_epochs = num_epochs
34
37
  self.batch_size = batch_size
35
38
 
@@ -57,8 +60,10 @@ class FlowerClient(NumPyClient):
57
60
  return loss.item(), len(self.test_images), {"accuracy": accuracy.item()}
58
61
 
59
62
 
60
- def client_fn(cid):
61
- data = load_data(int(cid), 2)
63
+ def client_fn(context: Context):
64
+ partition_id = int(context.node_config["partition-id"])
65
+ num_partitions = int(context.node_config["num-partitions"])
66
+ data = load_data(partition_id, num_partitions)
62
67
 
63
68
  # Return Client instance
64
69
  return FlowerClient(data).to_client()
@@ -1,6 +1,7 @@
1
1
  """$project_name: A Flower / NumPy app."""
2
2
 
3
3
  from flwr.client import NumPyClient, ClientApp
4
+ from flwr.common import Context
4
5
  import numpy as np
5
6
 
6
7
 
@@ -15,7 +16,7 @@ class FlowerClient(NumPyClient):
15
16
  return float(0.0), 1, {"accuracy": float(1.0)}
16
17
 
17
18
 
18
- def client_fn(cid: str):
19
+ def client_fn(context: Context):
19
20
  return FlowerClient().to_client()
20
21
 
21
22
 
@@ -1,6 +1,7 @@
1
1
  """$project_name: A Flower / PyTorch app."""
2
2
 
3
3
  from flwr.client import NumPyClient, ClientApp
4
+ from flwr.common import Context
4
5
 
5
6
  from $import_name.task import (
6
7
  Net,
@@ -22,7 +23,13 @@ class FlowerClient(NumPyClient):
22
23
 
23
24
  def fit(self, parameters, config):
24
25
  set_weights(self.net, parameters)
25
- results = train(self.net, self.trainloader, self.valloader, 1, DEVICE)
26
+ results = train(
27
+ self.net,
28
+ self.trainloader,
29
+ self.valloader,
30
+ int(self.context.run_config["local-epochs"]),
31
+ DEVICE,
32
+ )
26
33
  return get_weights(self.net), len(self.trainloader.dataset), results
27
34
 
28
35
  def evaluate(self, parameters, config):
@@ -31,10 +38,12 @@ class FlowerClient(NumPyClient):
31
38
  return loss, len(self.valloader.dataset), {"accuracy": accuracy}
32
39
 
33
40
 
34
- def client_fn(cid):
41
+ def client_fn(context: Context):
35
42
  # Load model and data
36
43
  net = Net().to(DEVICE)
37
- trainloader, valloader = load_data(int(cid), 2)
44
+ partition_id = int(context.node_config["partition-id"])
45
+ num_partitions = int(context.node_config["num-partitions"])
46
+ trainloader, valloader = load_data(partition_id, num_partitions)
38
47
 
39
48
  # Return Client instance
40
49
  return FlowerClient(net, trainloader, valloader).to_client()
@@ -4,6 +4,7 @@ import warnings
4
4
 
5
5
  import numpy as np
6
6
  from flwr.client import NumPyClient, ClientApp
7
+ from flwr.common import Context
7
8
  from flwr_datasets import FederatedDataset
8
9
  from sklearn.linear_model import LogisticRegression
9
10
  from sklearn.metrics import log_loss
@@ -66,10 +67,12 @@ class FlowerClient(NumPyClient):
66
67
 
67
68
  return loss, len(self.X_test), {"accuracy": accuracy}
68
69
 
69
- fds = FederatedDataset(dataset="mnist", partitioners={"train": 2})
70
70
 
71
- def client_fn(cid: str):
72
- dataset = fds.load_partition(int(cid), "train").with_format("numpy")
71
+ def client_fn(context: Context):
72
+ partition_id = int(context.node_config["partition-id"])
73
+ num_partitions = int(context.node_config["num-partitions"])
74
+ fds = FederatedDataset(dataset="mnist", partitioners={"train": num_partitions})
75
+ dataset = fds.load_partition(partition_id, "train").with_format("numpy")
73
76
 
74
77
  X, y = dataset["image"].reshape((len(dataset), -1)), dataset["label"]
75
78
 
@@ -1,6 +1,7 @@
1
1
  """$project_name: A Flower / TensorFlow app."""
2
2
 
3
3
  from flwr.client import NumPyClient, ClientApp
4
+ from flwr.common import Context
4
5
 
5
6
  from $import_name.task import load_data, load_model
6
7
 
@@ -19,7 +20,13 @@ class FlowerClient(NumPyClient):
19
20
 
20
21
  def fit(self, parameters, config):
21
22
  self.model.set_weights(parameters)
22
- self.model.fit(self.x_train, self.y_train, epochs=1, batch_size=32, verbose=0)
23
+ self.model.fit(
24
+ self.x_train,
25
+ self.y_train,
26
+ epochs=int(self.context.run_config["local-epochs"]),
27
+ batch_size=int(self.context.run_config["batch-size"]),
28
+ verbose=bool(self.context.run_config.get("verbose")),
29
+ )
23
30
  return self.model.get_weights(), len(self.x_train), {}
24
31
 
25
32
  def evaluate(self, parameters, config):
@@ -28,10 +35,13 @@ class FlowerClient(NumPyClient):
28
35
  return loss, len(self.x_test), {"accuracy": accuracy}
29
36
 
30
37
 
31
- def client_fn(cid):
38
+ def client_fn(context: Context):
32
39
  # Load model and data
33
40
  net = load_model()
34
- x_train, y_train, x_test, y_test = load_data(int(cid), 2)
41
+
42
+ partition_id = int(context.node_config["partition-id"])
43
+ num_partitions = int(context.node_config["num-partitions"])
44
+ x_train, y_train, x_test, y_test = load_data(partition_id, num_partitions)
35
45
 
36
46
  # Return Client instance
37
47
  return FlowerClient(net, x_train, y_train, x_test, y_test).to_client()
@@ -12,10 +12,10 @@ from flwr.client import ClientApp
12
12
  from flwr.common import ndarrays_to_parameters
13
13
  from flwr.server import ServerApp, ServerConfig
14
14
 
15
- from $import_name.client import gen_client_fn, get_parameters
15
+ from $import_name.client_app import gen_client_fn, get_parameters
16
16
  from $import_name.dataset import get_tokenizer_and_data_collator_and_propt_formatting
17
17
  from $import_name.models import get_model
18
- from $import_name.server import fit_weighted_average, get_evaluate_fn, get_on_fit_config
18
+ from $import_name.server_app import fit_weighted_average, get_evaluate_fn, get_on_fit_config
19
19
 
20
20
  # Avoid warnings
21
21
  warnings.filterwarnings("ignore", category=UserWarning)
@@ -1,6 +1,6 @@
1
1
  """$project_name: A Flower / FlowerTune app."""
2
2
 
3
- from $import_name.client import set_parameters
3
+ from $import_name.client_app import set_parameters
4
4
  from $import_name.models import get_model
5
5
 
6
6
 
@@ -1,17 +1,22 @@
1
1
  """$project_name: A Flower / HuggingFace Transformers app."""
2
2
 
3
+ from flwr.common import Context
3
4
  from flwr.server.strategy import FedAvg
4
- from flwr.server import ServerApp, ServerConfig
5
+ from flwr.server import ServerApp, ServerAppComponents, ServerConfig
5
6
 
6
7
 
7
- # Define strategy
8
- strategy = FedAvg(
9
- fraction_fit=1.0,
10
- fraction_evaluate=1.0,
11
- )
8
+ def server_fn(context: Context):
9
+ # Read from config
10
+ num_rounds = int(context.run_config["num-server-rounds"])
12
11
 
13
- # Start server
14
- app = ServerApp(
15
- config=ServerConfig(num_rounds=3),
16
- strategy=strategy,
17
- )
12
+ # Define strategy
13
+ strategy = FedAvg(
14
+ fraction_fit=1.0,
15
+ fraction_evaluate=1.0,
16
+ )
17
+ config = ServerConfig(num_rounds=num_rounds)
18
+
19
+ return ServerAppComponents(strategy=strategy, config=config)
20
+
21
+ # Create ServerApp
22
+ app = ServerApp(server_fn=server_fn)
@@ -1,12 +1,19 @@
1
1
  """$project_name: A Flower / JAX app."""
2
2
 
3
- import flwr as fl
3
+ from flwr.common import Context
4
+ from flwr.server.strategy import FedAvg
5
+ from flwr.server import ServerApp, ServerAppComponents, ServerConfig
4
6
 
5
- # Configure the strategy
6
- strategy = fl.server.strategy.FedAvg()
7
7
 
8
- # Flower ServerApp
9
- app = fl.server.ServerApp(
10
- config=fl.server.ServerConfig(num_rounds=3),
11
- strategy=strategy,
12
- )
8
+ def server_fn(context: Context):
9
+ # Read from config
10
+ num_rounds = int(context.run_config["num-server-rounds"])
11
+
12
+ # Define strategy
13
+ strategy = FedAvg()
14
+ config = ServerConfig(num_rounds=num_rounds)
15
+
16
+ return ServerAppComponents(strategy=strategy, config=config)
17
+
18
+ # Create ServerApp
19
+ app = ServerApp(server_fn=server_fn)
@@ -1,15 +1,19 @@
1
1
  """$project_name: A Flower / MLX app."""
2
2
 
3
- from flwr.server import ServerApp, ServerConfig
3
+ from flwr.common import Context
4
+ from flwr.server import ServerApp, ServerAppComponents, ServerConfig
4
5
  from flwr.server.strategy import FedAvg
5
6
 
6
7
 
7
- # Define strategy
8
- strategy = FedAvg()
8
+ def server_fn(context: Context):
9
+ # Read from config
10
+ num_rounds = int(context.run_config["num-server-rounds"])
9
11
 
12
+ # Define strategy
13
+ strategy = FedAvg()
14
+ config = ServerConfig(num_rounds=num_rounds)
15
+
16
+ return ServerAppComponents(strategy=strategy, config=config)
10
17
 
11
18
  # Create ServerApp
12
- app = ServerApp(
13
- config=ServerConfig(num_rounds=3),
14
- strategy=strategy,
15
- )
19
+ app = ServerApp(server_fn=server_fn)