flwr-nightly 1.11.0.dev20240724__py3-none-any.whl → 1.11.0.dev20240811__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Further details were linked from the original registry page.

Files changed (59)
  1. flwr/cli/build.py +22 -20
  2. flwr/cli/config_utils.py +27 -8
  3. flwr/cli/new/new.py +23 -22
  4. flwr/cli/new/templates/app/README.md.tpl +1 -1
  5. flwr/cli/new/templates/app/code/__init__.py.tpl +1 -1
  6. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +1 -1
  7. flwr/cli/new/templates/app/code/client.jax.py.tpl +1 -1
  8. flwr/cli/new/templates/app/code/client.mlx.py.tpl +1 -1
  9. flwr/cli/new/templates/app/code/client.numpy.py.tpl +1 -1
  10. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +9 -8
  11. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +1 -1
  12. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +5 -8
  13. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +1 -1
  14. flwr/cli/new/templates/app/code/server.jax.py.tpl +1 -1
  15. flwr/cli/new/templates/app/code/server.mlx.py.tpl +1 -1
  16. flwr/cli/new/templates/app/code/server.numpy.py.tpl +1 -1
  17. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +7 -6
  18. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +1 -1
  19. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +4 -5
  20. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +1 -1
  21. flwr/cli/new/templates/app/code/task.jax.py.tpl +2 -2
  22. flwr/cli/new/templates/app/code/task.mlx.py.tpl +2 -1
  23. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +14 -20
  24. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +16 -4
  25. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +2 -2
  26. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  27. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +3 -3
  28. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  29. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -2
  30. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +2 -2
  31. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
  32. flwr/cli/run/run.py +15 -12
  33. flwr/client/grpc_rere_client/grpc_adapter.py +7 -0
  34. flwr/client/supernode/app.py +36 -28
  35. flwr/common/config.py +30 -0
  36. flwr/common/typing.py +8 -0
  37. flwr/proto/driver_pb2.py +22 -21
  38. flwr/proto/driver_pb2.pyi +7 -1
  39. flwr/proto/driver_pb2_grpc.py +35 -0
  40. flwr/proto/driver_pb2_grpc.pyi +14 -0
  41. flwr/proto/fab_pb2.py +6 -6
  42. flwr/proto/fab_pb2.pyi +8 -8
  43. flwr/proto/fleet_pb2.py +28 -27
  44. flwr/proto/fleet_pb2_grpc.py +35 -0
  45. flwr/proto/fleet_pb2_grpc.pyi +14 -0
  46. flwr/proto/run_pb2.py +8 -8
  47. flwr/proto/run_pb2.pyi +4 -1
  48. flwr/server/run_serverapp.py +28 -46
  49. flwr/server/superlink/driver/driver_servicer.py +7 -0
  50. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +7 -0
  51. flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
  52. flwr/server/superlink/fleet/vce/backend/raybackend.py +4 -35
  53. flwr/server/superlink/fleet/vce/vce_api.py +3 -3
  54. flwr/superexec/simulation.py +15 -3
  55. {flwr_nightly-1.11.0.dev20240724.dist-info → flwr_nightly-1.11.0.dev20240811.dist-info}/METADATA +2 -2
  56. {flwr_nightly-1.11.0.dev20240724.dist-info → flwr_nightly-1.11.0.dev20240811.dist-info}/RECORD +59 -59
  57. {flwr_nightly-1.11.0.dev20240724.dist-info → flwr_nightly-1.11.0.dev20240811.dist-info}/LICENSE +0 -0
  58. {flwr_nightly-1.11.0.dev20240724.dist-info → flwr_nightly-1.11.0.dev20240811.dist-info}/WHEEL +0 -0
  59. {flwr_nightly-1.11.0.dev20240724.dist-info → flwr_nightly-1.11.0.dev20240811.dist-info}/entry_points.txt +0 -0
flwr/cli/build.py CHANGED
@@ -30,32 +30,34 @@ from .utils import get_sha256_hash, is_valid_project_name
30
30
 
31
31
  # pylint: disable=too-many-locals
32
32
  def build(
33
- directory: Annotated[
33
+ app: Annotated[
34
34
  Optional[Path],
35
- typer.Option(help="Path of the Flower project to bundle into a FAB"),
35
+ typer.Option(help="Path of the Flower App to bundle into a FAB"),
36
36
  ] = None,
37
37
  ) -> str:
38
- """Build a Flower project into a Flower App Bundle (FAB).
38
+ """Build a Flower App into a Flower App Bundle (FAB).
39
39
 
40
- You can run ``flwr build`` without any arguments to bundle the current directory,
41
- or you can use ``--directory`` to build a specific directory:
42
- ``flwr build --directory ./projects/flower-hello-world``.
40
+ You can run ``flwr build`` without any arguments to bundle the app located in the
41
+ current directory. Alternatively, you can you can specify a path using the ``--app``
42
+ option to bundle an app located at the provided path. For example:
43
+
44
+ ``flwr build --app ./apps/flower-hello-world``.
43
45
  """
44
- if directory is None:
45
- directory = Path.cwd()
46
+ if app is None:
47
+ app = Path.cwd()
46
48
 
47
- directory = directory.resolve()
48
- if not directory.is_dir():
49
+ app = app.resolve()
50
+ if not app.is_dir():
49
51
  typer.secho(
50
- f"❌ The path {directory} is not a valid directory.",
52
+ f"❌ The path {app} is not a valid path to a Flower app.",
51
53
  fg=typer.colors.RED,
52
54
  bold=True,
53
55
  )
54
56
  raise typer.Exit(code=1)
55
57
 
56
- if not is_valid_project_name(directory.name):
58
+ if not is_valid_project_name(app.name):
57
59
  typer.secho(
58
- f"❌ The project name {directory.name} is invalid, "
60
+ f"❌ The project name {app.name} is invalid, "
59
61
  "a valid project name must start with a letter or an underscore, "
60
62
  "and can only contain letters, digits, and underscores.",
61
63
  fg=typer.colors.RED,
@@ -63,7 +65,7 @@ def build(
63
65
  )
64
66
  raise typer.Exit(code=1)
65
67
 
66
- conf, errors, warnings = load_and_validate(directory / "pyproject.toml")
68
+ conf, errors, warnings = load_and_validate(app / "pyproject.toml")
67
69
  if conf is None:
68
70
  typer.secho(
69
71
  "Project configuration could not be loaded.\npyproject.toml is invalid:\n"
@@ -82,12 +84,12 @@ def build(
82
84
  )
83
85
 
84
86
  # Load .gitignore rules if present
85
- ignore_spec = _load_gitignore(directory)
87
+ ignore_spec = _load_gitignore(app)
86
88
 
87
89
  # Set the name of the zip file
88
90
  fab_filename = (
89
91
  f"{conf['tool']['flwr']['app']['publisher']}"
90
- f".{directory.name}"
92
+ f".{app.name}"
91
93
  f".{conf['project']['version'].replace('.', '-')}.fab"
92
94
  )
93
95
  list_file_content = ""
@@ -108,7 +110,7 @@ def build(
108
110
  fab_file.writestr("pyproject.toml", toml_contents)
109
111
 
110
112
  # Continue with adding other files
111
- for root, _, files in os.walk(directory, topdown=True):
113
+ for root, _, files in os.walk(app, topdown=True):
112
114
  files = [
113
115
  f
114
116
  for f in files
@@ -120,7 +122,7 @@ def build(
120
122
 
121
123
  for file in files:
122
124
  file_path = Path(root) / file
123
- archive_path = file_path.relative_to(directory)
125
+ archive_path = file_path.relative_to(app)
124
126
  fab_file.write(file_path, archive_path)
125
127
 
126
128
  # Calculate file info
@@ -138,9 +140,9 @@ def build(
138
140
  return fab_filename
139
141
 
140
142
 
141
- def _load_gitignore(directory: Path) -> pathspec.PathSpec:
143
+ def _load_gitignore(app: Path) -> pathspec.PathSpec:
142
144
  """Load and parse .gitignore file, returning a pathspec."""
143
- gitignore_path = directory / ".gitignore"
145
+ gitignore_path = app / ".gitignore"
144
146
  patterns = ["__pycache__/"] # Default pattern
145
147
  if gitignore_path.exists():
146
148
  with open(gitignore_path, encoding="UTF-8") as file:
flwr/cli/config_utils.py CHANGED
@@ -25,8 +25,8 @@ from flwr.common import object_ref
25
25
  from flwr.common.typing import UserConfigValue
26
26
 
27
27
 
28
- def get_fab_metadata(fab_file: Union[Path, bytes]) -> Tuple[str, str]:
29
- """Extract the fab_id and the fab_version from a FAB file or path.
28
+ def get_fab_config(fab_file: Union[Path, bytes]) -> Dict[str, Any]:
29
+ """Extract the config from a FAB file or path.
30
30
 
31
31
  Parameters
32
32
  ----------
@@ -36,8 +36,8 @@ def get_fab_metadata(fab_file: Union[Path, bytes]) -> Tuple[str, str]:
36
36
 
37
37
  Returns
38
38
  -------
39
- Tuple[str, str]
40
- The `fab_version` and `fab_id` of the given Flower App Bundle.
39
+ Dict[str, Any]
40
+ The `config` of the given Flower App Bundle.
41
41
  """
42
42
  fab_file_archive: Union[Path, IO[bytes]]
43
43
  if isinstance(fab_file, bytes):
@@ -59,10 +59,29 @@ def get_fab_metadata(fab_file: Union[Path, bytes]) -> Tuple[str, str]:
59
59
  if not is_valid:
60
60
  raise ValueError(errors)
61
61
 
62
- return (
63
- conf["project"]["version"],
64
- f"{conf['tool']['flwr']['app']['publisher']}/{conf['project']['name']}",
65
- )
62
+ return conf
63
+
64
+
65
+ def get_fab_metadata(fab_file: Union[Path, bytes]) -> Tuple[str, str]:
66
+ """Extract the fab_id and the fab_version from a FAB file or path.
67
+
68
+ Parameters
69
+ ----------
70
+ fab_file : Union[Path, bytes]
71
+ The Flower App Bundle file to validate and extract the metadata from.
72
+ It can either be a path to the file or the file itself as bytes.
73
+
74
+ Returns
75
+ -------
76
+ Tuple[str, str]
77
+ The `fab_version` and `fab_id` of the given Flower App Bundle.
78
+ """
79
+ conf = get_fab_config(fab_file)
80
+
81
+ return (
82
+ conf["project"]["version"],
83
+ f"{conf['tool']['flwr']['app']['publisher']}/{conf['project']['name']}",
84
+ )
66
85
 
67
86
 
68
87
  def load_and_validate(
flwr/cli/new/new.py CHANGED
@@ -34,13 +34,13 @@ from ..utils import (
34
34
  class MlFramework(str, Enum):
35
35
  """Available frameworks."""
36
36
 
37
- NUMPY = "NumPy"
38
37
  PYTORCH = "PyTorch"
39
38
  TENSORFLOW = "TensorFlow"
40
- JAX = "JAX"
39
+ SKLEARN = "sklearn"
41
40
  HUGGINGFACE = "HuggingFace"
41
+ JAX = "JAX"
42
42
  MLX = "MLX"
43
- SKLEARN = "sklearn"
43
+ NUMPY = "NumPy"
44
44
  FLOWERTUNE = "FlowerTune"
45
45
 
46
46
 
@@ -92,9 +92,9 @@ def render_and_create(file_path: Path, template: str, context: Dict[str, str]) -
92
92
 
93
93
  # pylint: disable=too-many-locals,too-many-branches,too-many-statements
94
94
  def new(
95
- project_name: Annotated[
95
+ app_name: Annotated[
96
96
  Optional[str],
97
- typer.Argument(metavar="project_name", help="The name of the project"),
97
+ typer.Argument(help="The name of the Flower App"),
98
98
  ] = None,
99
99
  framework: Annotated[
100
100
  Optional[MlFramework],
@@ -105,26 +105,26 @@ def new(
105
105
  typer.Option(case_sensitive=False, help="The Flower username of the author"),
106
106
  ] = None,
107
107
  ) -> None:
108
- """Create new Flower project."""
109
- if project_name is None:
110
- project_name = prompt_text("Please provide the project name")
111
- if not is_valid_project_name(project_name):
112
- project_name = prompt_text(
108
+ """Create new Flower App."""
109
+ if app_name is None:
110
+ app_name = prompt_text("Please provide the app name")
111
+ if not is_valid_project_name(app_name):
112
+ app_name = prompt_text(
113
113
  "Please provide a name that only contains "
114
114
  "characters in {'-', a-zA-Z', '0-9'}",
115
115
  predicate=is_valid_project_name,
116
- default=sanitize_project_name(project_name),
116
+ default=sanitize_project_name(app_name),
117
117
  )
118
118
 
119
119
  # Set project directory path
120
- package_name = re.sub(r"[-_.]+", "-", project_name).lower()
120
+ package_name = re.sub(r"[-_.]+", "-", app_name).lower()
121
121
  import_name = package_name.replace("-", "_")
122
122
  project_dir = Path.cwd() / package_name
123
123
 
124
124
  if project_dir.exists():
125
125
  if not typer.confirm(
126
126
  typer.style(
127
- f"\n💬 {project_name} already exists, do you want to override it?",
127
+ f"\n💬 {app_name} already exists, do you want to override it?",
128
128
  fg=typer.colors.MAGENTA,
129
129
  bold=True,
130
130
  )
@@ -135,20 +135,20 @@ def new(
135
135
  username = prompt_text("Please provide your Flower username")
136
136
 
137
137
  if framework is not None:
138
- framework_str = str(framework.value)
138
+ framework_str_upper = str(framework.value)
139
139
  else:
140
140
  framework_value = prompt_options(
141
141
  "Please select ML framework by typing in the number",
142
- sorted([mlf.value for mlf in MlFramework]),
142
+ [mlf.value for mlf in MlFramework],
143
143
  )
144
144
  selected_value = [
145
145
  name
146
146
  for name, value in vars(MlFramework).items()
147
147
  if value == framework_value
148
148
  ]
149
- framework_str = selected_value[0]
149
+ framework_str_upper = selected_value[0]
150
150
 
151
- framework_str = framework_str.lower()
151
+ framework_str = framework_str_upper.lower()
152
152
 
153
153
  llm_challenge_str = None
154
154
  if framework_str == "flowertune":
@@ -166,16 +166,17 @@ def new(
166
166
 
167
167
  print(
168
168
  typer.style(
169
- f"\n🔨 Creating Flower project {project_name}...",
169
+ f"\n🔨 Creating Flower App {app_name}...",
170
170
  fg=typer.colors.GREEN,
171
171
  bold=True,
172
172
  )
173
173
  )
174
174
 
175
175
  context = {
176
- "project_name": project_name,
177
- "package_name": package_name,
176
+ "framework_str": framework_str_upper,
178
177
  "import_name": import_name.replace("-", "_"),
178
+ "package_name": package_name,
179
+ "project_name": app_name,
179
180
  "username": username,
180
181
  }
181
182
 
@@ -267,8 +268,8 @@ def new(
267
268
 
268
269
  print(
269
270
  typer.style(
270
- "🎊 Project creation successful.\n\n"
271
- "Use the following command to run your project:\n",
271
+ "🎊 Flower App creation successful.\n\n"
272
+ "Use the following command to run your Flower App:\n",
272
273
  fg=typer.colors.GREEN,
273
274
  bold=True,
274
275
  )
flwr/cli/new/templates/app/README.md.tpl CHANGED
@@ -1,4 +1,4 @@
1
- # $project_name
1
+ # $project_name: A Flower / $framework_str app
2
2
 
3
3
  ## Install dependencies
4
4
 
flwr/cli/new/templates/app/code/__init__.py.tpl CHANGED
@@ -1 +1 @@
1
- """$project_name."""
1
+ """$project_name: A Flower / $framework_str app."""
flwr/cli/new/templates/app/code/client.huggingface.py.tpl CHANGED
@@ -1,4 +1,4 @@
1
- """$project_name: A Flower / HuggingFace Transformers app."""
1
+ """$project_name: A Flower / $framework_str app."""
2
2
 
3
3
  from flwr.client import ClientApp, NumPyClient
4
4
  from flwr.common import Context
flwr/cli/new/templates/app/code/client.jax.py.tpl CHANGED
@@ -1,4 +1,4 @@
1
- """$project_name: A Flower / JAX app."""
1
+ """$project_name: A Flower / $framework_str app."""
2
2
 
3
3
  import jax
4
4
  from flwr.client import NumPyClient, ClientApp
flwr/cli/new/templates/app/code/client.mlx.py.tpl CHANGED
@@ -1,4 +1,4 @@
1
- """$project_name: A Flower / MLX app."""
1
+ """$project_name: A Flower / $framework_str app."""
2
2
 
3
3
  import mlx.core as mx
4
4
  import mlx.nn as nn
flwr/cli/new/templates/app/code/client.numpy.py.tpl CHANGED
@@ -1,4 +1,4 @@
1
- """$project_name: A Flower / NumPy app."""
1
+ """$project_name: A Flower / $framework_str app."""
2
2
 
3
3
  from flwr.client import NumPyClient, ClientApp
4
4
  from flwr.common import Context
flwr/cli/new/templates/app/code/client.pytorch.py.tpl CHANGED
@@ -1,11 +1,11 @@
1
- """$project_name: A Flower / PyTorch app."""
1
+ """$project_name: A Flower / $framework_str app."""
2
2
 
3
+ import torch
3
4
  from flwr.client import NumPyClient, ClientApp
4
5
  from flwr.common import Context
5
6
 
6
7
  from $import_name.task import (
7
8
  Net,
8
- DEVICE,
9
9
  load_data,
10
10
  get_weights,
11
11
  set_weights,
@@ -21,27 +21,28 @@ class FlowerClient(NumPyClient):
21
21
  self.trainloader = trainloader
22
22
  self.valloader = valloader
23
23
  self.local_epochs = local_epochs
24
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
25
+ self.net.to(self.device)
24
26
 
25
27
  def fit(self, parameters, config):
26
28
  set_weights(self.net, parameters)
27
- results = train(
29
+ train_loss = train(
28
30
  self.net,
29
31
  self.trainloader,
30
- self.valloader,
31
32
  self.local_epochs,
32
- DEVICE,
33
+ self.device,
33
34
  )
34
- return get_weights(self.net), len(self.trainloader.dataset), results
35
+ return get_weights(self.net), len(self.trainloader.dataset), {"train_loss": train_loss}
35
36
 
36
37
  def evaluate(self, parameters, config):
37
38
  set_weights(self.net, parameters)
38
- loss, accuracy = test(self.net, self.valloader)
39
+ loss, accuracy = test(self.net, self.valloader, self.device)
39
40
  return loss, len(self.valloader.dataset), {"accuracy": accuracy}
40
41
 
41
42
 
42
43
  def client_fn(context: Context):
43
44
  # Load model and data
44
- net = Net().to(DEVICE)
45
+ net = Net()
45
46
  partition_id = context.node_config["partition-id"]
46
47
  num_partitions = context.node_config["num-partitions"]
47
48
  trainloader, valloader = load_data(partition_id, num_partitions)
flwr/cli/new/templates/app/code/client.sklearn.py.tpl CHANGED
@@ -1,4 +1,4 @@
1
- """$project_name: A Flower / Scikit-Learn app."""
1
+ """$project_name: A Flower / $framework_str app."""
2
2
 
3
3
  import warnings
4
4
 
flwr/cli/new/templates/app/code/client.tensorflow.py.tpl CHANGED
@@ -1,4 +1,4 @@
1
- """$project_name: A Flower / TensorFlow app."""
1
+ """$project_name: A Flower / $framework_str app."""
2
2
 
3
3
  from flwr.client import NumPyClient, ClientApp
4
4
  from flwr.common import Context
@@ -9,13 +9,10 @@ from $import_name.task import load_data, load_model
9
9
  # Define Flower Client and client_fn
10
10
  class FlowerClient(NumPyClient):
11
11
  def __init__(
12
- self, model, x_train, y_train, x_test, y_test, epochs, batch_size, verbose
12
+ self, model, data, epochs, batch_size, verbose
13
13
  ):
14
14
  self.model = model
15
- self.x_train = x_train
16
- self.y_train = y_train
17
- self.x_test = x_test
18
- self.y_test = y_test
15
+ self.x_train, self.y_train, self.x_test, self.y_test = data
19
16
  self.epochs = epochs
20
17
  self.batch_size = batch_size
21
18
  self.verbose = verbose
@@ -46,14 +43,14 @@ def client_fn(context: Context):
46
43
 
47
44
  partition_id = context.node_config["partition-id"]
48
45
  num_partitions = context.node_config["num-partitions"]
49
- x_train, y_train, x_test, y_test = load_data(partition_id, num_partitions)
46
+ data = load_data(partition_id, num_partitions)
50
47
  epochs = context.run_config["local-epochs"]
51
48
  batch_size = context.run_config["batch-size"]
52
49
  verbose = context.run_config.get("verbose")
53
50
 
54
51
  # Return Client instance
55
52
  return FlowerClient(
56
- net, x_train, y_train, x_test, y_test, epochs, batch_size, verbose
53
+ net, data, epochs, batch_size, verbose
57
54
  ).to_client()
58
55
 
59
56
 
flwr/cli/new/templates/app/code/server.huggingface.py.tpl CHANGED
@@ -1,4 +1,4 @@
1
- """$project_name: A Flower / HuggingFace Transformers app."""
1
+ """$project_name: A Flower / $framework_str app."""
2
2
 
3
3
  from flwr.common import Context
4
4
  from flwr.server.strategy import FedAvg
flwr/cli/new/templates/app/code/server.jax.py.tpl CHANGED
@@ -1,4 +1,4 @@
1
- """$project_name: A Flower / JAX app."""
1
+ """$project_name: A Flower / $framework_str app."""
2
2
 
3
3
  from flwr.common import Context
4
4
  from flwr.server.strategy import FedAvg
flwr/cli/new/templates/app/code/server.mlx.py.tpl CHANGED
@@ -1,4 +1,4 @@
1
- """$project_name: A Flower / MLX app."""
1
+ """$project_name: A Flower / $framework_str app."""
2
2
 
3
3
  from flwr.common import Context
4
4
  from flwr.server import ServerApp, ServerAppComponents, ServerConfig
flwr/cli/new/templates/app/code/server.numpy.py.tpl CHANGED
@@ -1,4 +1,4 @@
1
- """$project_name: A Flower / NumPy app."""
1
+ """$project_name: A Flower / $framework_str app."""
2
2
 
3
3
  from flwr.common import Context
4
4
  from flwr.server import ServerApp, ServerAppComponents, ServerConfig
flwr/cli/new/templates/app/code/server.pytorch.py.tpl CHANGED
@@ -1,4 +1,4 @@
1
- """$project_name: A Flower / PyTorch app."""
1
+ """$project_name: A Flower / $framework_str app."""
2
2
 
3
3
  from flwr.common import Context, ndarrays_to_parameters
4
4
  from flwr.server import ServerApp, ServerAppComponents, ServerConfig
@@ -7,17 +7,18 @@ from flwr.server.strategy import FedAvg
7
7
  from $import_name.task import Net, get_weights
8
8
 
9
9
 
10
- # Initialize model parameters
11
- ndarrays = get_weights(Net())
12
- parameters = ndarrays_to_parameters(ndarrays)
13
-
14
10
  def server_fn(context: Context):
15
11
  # Read from config
16
12
  num_rounds = context.run_config["num-server-rounds"]
13
+ fraction_fit = context.run_config["fraction-fit"]
14
+
15
+ # Initialize model parameters
16
+ ndarrays = get_weights(Net())
17
+ parameters = ndarrays_to_parameters(ndarrays)
17
18
 
18
19
  # Define strategy
19
20
  strategy = FedAvg(
20
- fraction_fit=1.0,
21
+ fraction_fit=fraction_fit,
21
22
  fraction_evaluate=1.0,
22
23
  min_available_clients=2,
23
24
  initial_parameters=parameters,
flwr/cli/new/templates/app/code/server.sklearn.py.tpl CHANGED
@@ -1,4 +1,4 @@
1
- """$project_name: A Flower / Scikit-Learn app."""
1
+ """$project_name: A Flower / $framework_str app."""
2
2
 
3
3
  from flwr.common import Context
4
4
  from flwr.server import ServerApp, ServerAppComponents, ServerConfig
flwr/cli/new/templates/app/code/server.tensorflow.py.tpl CHANGED
@@ -1,4 +1,4 @@
1
- """$project_name: A Flower / TensorFlow app."""
1
+ """$project_name: A Flower / $framework_str app."""
2
2
 
3
3
  from flwr.common import Context, ndarrays_to_parameters
4
4
  from flwr.server import ServerApp, ServerAppComponents, ServerConfig
@@ -6,15 +6,14 @@ from flwr.server.strategy import FedAvg
6
6
 
7
7
  from $import_name.task import load_model
8
8
 
9
- # Define config
10
- config = ServerConfig(num_rounds=3)
11
-
12
- parameters = ndarrays_to_parameters(load_model().get_weights())
13
9
 
14
10
  def server_fn(context: Context):
15
11
  # Read from config
16
12
  num_rounds = context.run_config["num-server-rounds"]
17
13
 
14
+ # Get parameters to initialize global model
15
+ parameters = ndarrays_to_parameters(load_model().get_weights())
16
+
18
17
  # Define strategy
19
18
  strategy = strategy = FedAvg(
20
19
  fraction_fit=1.0,
flwr/cli/new/templates/app/code/task.huggingface.py.tpl CHANGED
@@ -1,4 +1,4 @@
1
- """$project_name: A Flower / HuggingFace Transformers app."""
1
+ """$project_name: A Flower / $framework_str app."""
2
2
 
3
3
  import warnings
4
4
  from collections import OrderedDict
flwr/cli/new/templates/app/code/task.jax.py.tpl CHANGED
@@ -1,4 +1,4 @@
1
- """$project_name: A Flower / JAX app."""
1
+ """$project_name: A Flower / $framework_str app."""
2
2
 
3
3
  import jax
4
4
  import jax.numpy as jnp
@@ -33,7 +33,7 @@ def train(params, grad_fn, X, y):
33
33
  num_examples = X.shape[0]
34
34
  for epochs in range(50):
35
35
  grads = grad_fn(params, X, y)
36
- params = jax.tree.map(lambda p, g: p - 0.05 * g, params, grads)
36
+ params = jax.tree_map(lambda p, g: p - 0.05 * g, params, grads)
37
37
  loss = loss_fn(params, X, y)
38
38
  return params, loss, num_examples
39
39
 
flwr/cli/new/templates/app/code/task.mlx.py.tpl CHANGED
@@ -1,4 +1,4 @@
1
- """$project_name: A Flower / MLX app."""
1
+ """$project_name: A Flower / $framework_str app."""
2
2
 
3
3
  import mlx.core as mx
4
4
  import mlx.nn as nn
@@ -56,6 +56,7 @@ def load_data(partition_id: int, num_partitions: int):
56
56
  fds = FederatedDataset(
57
57
  dataset="ylecun/mnist",
58
58
  partitioners={"train": partitioner},
59
+ trust_remote_code=True,
59
60
  )
60
61
  partition = fds.load_partition(partition_id)
61
62
  partition_splits = partition.train_test_split(test_size=0.2, seed=42)
flwr/cli/new/templates/app/code/task.pytorch.py.tpl CHANGED
@@ -1,4 +1,4 @@
1
- """$project_name: A Flower / PyTorch app."""
1
+ """$project_name: A Flower / $framework_str app."""
2
2
 
3
3
  from collections import OrderedDict
4
4
 
@@ -11,9 +11,6 @@ from flwr_datasets import FederatedDataset
11
11
  from flwr_datasets.partitioner import IidPartitioner
12
12
 
13
13
 
14
- DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
15
-
16
-
17
14
  class Net(nn.Module):
18
15
  """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')"""
19
16
 
@@ -66,44 +63,41 @@ def load_data(partition_id: int, num_partitions: int):
66
63
  return trainloader, testloader
67
64
 
68
65
 
69
- def train(net, trainloader, valloader, epochs, device):
66
+ def train(net, trainloader, epochs, device):
70
67
  """Train the model on the training set."""
71
68
  net.to(device) # move model to GPU if available
72
69
  criterion = torch.nn.CrossEntropyLoss().to(device)
73
- optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
70
+ optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
74
71
  net.train()
72
+ running_loss = 0.0
75
73
  for _ in range(epochs):
76
74
  for batch in trainloader:
77
75
  images = batch["img"]
78
76
  labels = batch["label"]
79
77
  optimizer.zero_grad()
80
- criterion(net(images.to(DEVICE)), labels.to(DEVICE)).backward()
78
+ loss = criterion(net(images.to(device)), labels.to(device))
79
+ loss.backward()
81
80
  optimizer.step()
81
+ running_loss += loss.item()
82
82
 
83
- train_loss, train_acc = test(net, trainloader)
84
- val_loss, val_acc = test(net, valloader)
85
-
86
- results = {
87
- "train_loss": train_loss,
88
- "train_accuracy": train_acc,
89
- "val_loss": val_loss,
90
- "val_accuracy": val_acc,
91
- }
92
- return results
83
+ avg_trainloss = running_loss / len(trainloader)
84
+ return avg_trainloss
93
85
 
94
86
 
95
- def test(net, testloader):
87
+ def test(net, testloader, device):
96
88
  """Validate the model on the test set."""
89
+ net.to(device)
97
90
  criterion = torch.nn.CrossEntropyLoss()
98
91
  correct, loss = 0, 0.0
99
92
  with torch.no_grad():
100
93
  for batch in testloader:
101
- images = batch["img"].to(DEVICE)
102
- labels = batch["label"].to(DEVICE)
94
+ images = batch["img"].to(device)
95
+ labels = batch["label"].to(device)
103
96
  outputs = net(images)
104
97
  loss += criterion(outputs, labels).item()
105
98
  correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
106
99
  accuracy = correct / len(testloader.dataset)
100
+ loss = loss / len(testloader)
107
101
  return loss, accuracy
108
102
 
109
103
 
flwr/cli/new/templates/app/code/task.tensorflow.py.tpl CHANGED
@@ -1,8 +1,9 @@
1
- """$project_name: A Flower / TensorFlow app."""
1
+ """$project_name: A Flower / $framework_str app."""
2
2
 
3
3
  import os
4
4
 
5
- import tensorflow as tf
5
+ import keras
6
+ from keras import layers
6
7
  from flwr_datasets import FederatedDataset
7
8
  from flwr_datasets.partitioner import IidPartitioner
8
9
 
@@ -12,8 +13,19 @@ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
12
13
 
13
14
 
14
15
  def load_model():
15
- # Load model and data (MobileNetV2, CIFAR-10)
16
- model = tf.keras.applications.MobileNetV2((32, 32, 3), classes=10, weights=None)
16
+ # Define a simple CNN for CIFAR-10 and set Adam optimizer
17
+ model = keras.Sequential(
18
+ [
19
+ keras.Input(shape=(32, 32, 3)),
20
+ layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
21
+ layers.MaxPooling2D(pool_size=(2, 2)),
22
+ layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
23
+ layers.MaxPooling2D(pool_size=(2, 2)),
24
+ layers.Flatten(),
25
+ layers.Dropout(0.5),
26
+ layers.Dense(10, activation="softmax"),
27
+ ]
28
+ )
17
29
  model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"])
18
30
  return model
19
31
 
flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl CHANGED
@@ -8,8 +8,8 @@ version = "1.0.0"
8
8
  description = ""
9
9
  license = "Apache-2.0"
10
10
  dependencies = [
11
- "flwr[simulation]>=1.9.0,<2.0",
12
- "flwr-datasets>=0.0.2,<1.0.0",
11
+ "flwr[simulation]>=1.10.0",
12
+ "flwr-datasets>=0.3.0",
13
13
  "torch==2.2.1",
14
14
  "transformers>=4.30.0,<5.0",
15
15
  "evaluate>=0.4.0,<1.0",