flwr-nightly 1.11.0.dev20240724__py3-none-any.whl → 1.11.0.dev20240811__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/build.py +22 -20
- flwr/cli/config_utils.py +27 -8
- flwr/cli/new/new.py +23 -22
- flwr/cli/new/templates/app/README.md.tpl +1 -1
- flwr/cli/new/templates/app/code/__init__.py.tpl +1 -1
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +1 -1
- flwr/cli/new/templates/app/code/client.jax.py.tpl +1 -1
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +1 -1
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +1 -1
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +9 -8
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +1 -1
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +5 -8
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.jax.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +7 -6
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +4 -5
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +1 -1
- flwr/cli/new/templates/app/code/task.jax.py.tpl +2 -2
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +2 -1
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +14 -20
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +16 -4
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +3 -3
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -2
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
- flwr/cli/run/run.py +15 -12
- flwr/client/grpc_rere_client/grpc_adapter.py +7 -0
- flwr/client/supernode/app.py +36 -28
- flwr/common/config.py +30 -0
- flwr/common/typing.py +8 -0
- flwr/proto/driver_pb2.py +22 -21
- flwr/proto/driver_pb2.pyi +7 -1
- flwr/proto/driver_pb2_grpc.py +35 -0
- flwr/proto/driver_pb2_grpc.pyi +14 -0
- flwr/proto/fab_pb2.py +6 -6
- flwr/proto/fab_pb2.pyi +8 -8
- flwr/proto/fleet_pb2.py +28 -27
- flwr/proto/fleet_pb2_grpc.py +35 -0
- flwr/proto/fleet_pb2_grpc.pyi +14 -0
- flwr/proto/run_pb2.py +8 -8
- flwr/proto/run_pb2.pyi +4 -1
- flwr/server/run_serverapp.py +28 -46
- flwr/server/superlink/driver/driver_servicer.py +7 -0
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +7 -0
- flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +4 -35
- flwr/server/superlink/fleet/vce/vce_api.py +3 -3
- flwr/superexec/simulation.py +15 -3
- {flwr_nightly-1.11.0.dev20240724.dist-info → flwr_nightly-1.11.0.dev20240811.dist-info}/METADATA +2 -2
- {flwr_nightly-1.11.0.dev20240724.dist-info → flwr_nightly-1.11.0.dev20240811.dist-info}/RECORD +59 -59
- {flwr_nightly-1.11.0.dev20240724.dist-info → flwr_nightly-1.11.0.dev20240811.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.11.0.dev20240724.dist-info → flwr_nightly-1.11.0.dev20240811.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.11.0.dev20240724.dist-info → flwr_nightly-1.11.0.dev20240811.dist-info}/entry_points.txt +0 -0
flwr/cli/build.py
CHANGED
|
@@ -30,32 +30,34 @@ from .utils import get_sha256_hash, is_valid_project_name
|
|
|
30
30
|
|
|
31
31
|
# pylint: disable=too-many-locals
|
|
32
32
|
def build(
|
|
33
|
-
|
|
33
|
+
app: Annotated[
|
|
34
34
|
Optional[Path],
|
|
35
|
-
typer.Option(help="Path of the Flower
|
|
35
|
+
typer.Option(help="Path of the Flower App to bundle into a FAB"),
|
|
36
36
|
] = None,
|
|
37
37
|
) -> str:
|
|
38
|
-
"""Build a Flower
|
|
38
|
+
"""Build a Flower App into a Flower App Bundle (FAB).
|
|
39
39
|
|
|
40
|
-
You can run ``flwr build`` without any arguments to bundle the
|
|
41
|
-
|
|
42
|
-
|
|
40
|
+
You can run ``flwr build`` without any arguments to bundle the app located in the
|
|
41
|
+
current directory. Alternatively, you can you can specify a path using the ``--app``
|
|
42
|
+
option to bundle an app located at the provided path. For example:
|
|
43
|
+
|
|
44
|
+
``flwr build --app ./apps/flower-hello-world``.
|
|
43
45
|
"""
|
|
44
|
-
if
|
|
45
|
-
|
|
46
|
+
if app is None:
|
|
47
|
+
app = Path.cwd()
|
|
46
48
|
|
|
47
|
-
|
|
48
|
-
if not
|
|
49
|
+
app = app.resolve()
|
|
50
|
+
if not app.is_dir():
|
|
49
51
|
typer.secho(
|
|
50
|
-
f"❌ The path {
|
|
52
|
+
f"❌ The path {app} is not a valid path to a Flower app.",
|
|
51
53
|
fg=typer.colors.RED,
|
|
52
54
|
bold=True,
|
|
53
55
|
)
|
|
54
56
|
raise typer.Exit(code=1)
|
|
55
57
|
|
|
56
|
-
if not is_valid_project_name(
|
|
58
|
+
if not is_valid_project_name(app.name):
|
|
57
59
|
typer.secho(
|
|
58
|
-
f"❌ The project name {
|
|
60
|
+
f"❌ The project name {app.name} is invalid, "
|
|
59
61
|
"a valid project name must start with a letter or an underscore, "
|
|
60
62
|
"and can only contain letters, digits, and underscores.",
|
|
61
63
|
fg=typer.colors.RED,
|
|
@@ -63,7 +65,7 @@ def build(
|
|
|
63
65
|
)
|
|
64
66
|
raise typer.Exit(code=1)
|
|
65
67
|
|
|
66
|
-
conf, errors, warnings = load_and_validate(
|
|
68
|
+
conf, errors, warnings = load_and_validate(app / "pyproject.toml")
|
|
67
69
|
if conf is None:
|
|
68
70
|
typer.secho(
|
|
69
71
|
"Project configuration could not be loaded.\npyproject.toml is invalid:\n"
|
|
@@ -82,12 +84,12 @@ def build(
|
|
|
82
84
|
)
|
|
83
85
|
|
|
84
86
|
# Load .gitignore rules if present
|
|
85
|
-
ignore_spec = _load_gitignore(
|
|
87
|
+
ignore_spec = _load_gitignore(app)
|
|
86
88
|
|
|
87
89
|
# Set the name of the zip file
|
|
88
90
|
fab_filename = (
|
|
89
91
|
f"{conf['tool']['flwr']['app']['publisher']}"
|
|
90
|
-
f".{
|
|
92
|
+
f".{app.name}"
|
|
91
93
|
f".{conf['project']['version'].replace('.', '-')}.fab"
|
|
92
94
|
)
|
|
93
95
|
list_file_content = ""
|
|
@@ -108,7 +110,7 @@ def build(
|
|
|
108
110
|
fab_file.writestr("pyproject.toml", toml_contents)
|
|
109
111
|
|
|
110
112
|
# Continue with adding other files
|
|
111
|
-
for root, _, files in os.walk(
|
|
113
|
+
for root, _, files in os.walk(app, topdown=True):
|
|
112
114
|
files = [
|
|
113
115
|
f
|
|
114
116
|
for f in files
|
|
@@ -120,7 +122,7 @@ def build(
|
|
|
120
122
|
|
|
121
123
|
for file in files:
|
|
122
124
|
file_path = Path(root) / file
|
|
123
|
-
archive_path = file_path.relative_to(
|
|
125
|
+
archive_path = file_path.relative_to(app)
|
|
124
126
|
fab_file.write(file_path, archive_path)
|
|
125
127
|
|
|
126
128
|
# Calculate file info
|
|
@@ -138,9 +140,9 @@ def build(
|
|
|
138
140
|
return fab_filename
|
|
139
141
|
|
|
140
142
|
|
|
141
|
-
def _load_gitignore(
|
|
143
|
+
def _load_gitignore(app: Path) -> pathspec.PathSpec:
|
|
142
144
|
"""Load and parse .gitignore file, returning a pathspec."""
|
|
143
|
-
gitignore_path =
|
|
145
|
+
gitignore_path = app / ".gitignore"
|
|
144
146
|
patterns = ["__pycache__/"] # Default pattern
|
|
145
147
|
if gitignore_path.exists():
|
|
146
148
|
with open(gitignore_path, encoding="UTF-8") as file:
|
flwr/cli/config_utils.py
CHANGED
|
@@ -25,8 +25,8 @@ from flwr.common import object_ref
|
|
|
25
25
|
from flwr.common.typing import UserConfigValue
|
|
26
26
|
|
|
27
27
|
|
|
28
|
-
def
|
|
29
|
-
"""Extract the
|
|
28
|
+
def get_fab_config(fab_file: Union[Path, bytes]) -> Dict[str, Any]:
|
|
29
|
+
"""Extract the config from a FAB file or path.
|
|
30
30
|
|
|
31
31
|
Parameters
|
|
32
32
|
----------
|
|
@@ -36,8 +36,8 @@ def get_fab_metadata(fab_file: Union[Path, bytes]) -> Tuple[str, str]:
|
|
|
36
36
|
|
|
37
37
|
Returns
|
|
38
38
|
-------
|
|
39
|
-
|
|
40
|
-
The `
|
|
39
|
+
Dict[str, Any]
|
|
40
|
+
The `config` of the given Flower App Bundle.
|
|
41
41
|
"""
|
|
42
42
|
fab_file_archive: Union[Path, IO[bytes]]
|
|
43
43
|
if isinstance(fab_file, bytes):
|
|
@@ -59,10 +59,29 @@ def get_fab_metadata(fab_file: Union[Path, bytes]) -> Tuple[str, str]:
|
|
|
59
59
|
if not is_valid:
|
|
60
60
|
raise ValueError(errors)
|
|
61
61
|
|
|
62
|
-
return
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
62
|
+
return conf
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def get_fab_metadata(fab_file: Union[Path, bytes]) -> Tuple[str, str]:
|
|
66
|
+
"""Extract the fab_id and the fab_version from a FAB file or path.
|
|
67
|
+
|
|
68
|
+
Parameters
|
|
69
|
+
----------
|
|
70
|
+
fab_file : Union[Path, bytes]
|
|
71
|
+
The Flower App Bundle file to validate and extract the metadata from.
|
|
72
|
+
It can either be a path to the file or the file itself as bytes.
|
|
73
|
+
|
|
74
|
+
Returns
|
|
75
|
+
-------
|
|
76
|
+
Tuple[str, str]
|
|
77
|
+
The `fab_version` and `fab_id` of the given Flower App Bundle.
|
|
78
|
+
"""
|
|
79
|
+
conf = get_fab_config(fab_file)
|
|
80
|
+
|
|
81
|
+
return (
|
|
82
|
+
conf["project"]["version"],
|
|
83
|
+
f"{conf['tool']['flwr']['app']['publisher']}/{conf['project']['name']}",
|
|
84
|
+
)
|
|
66
85
|
|
|
67
86
|
|
|
68
87
|
def load_and_validate(
|
flwr/cli/new/new.py
CHANGED
|
@@ -34,13 +34,13 @@ from ..utils import (
|
|
|
34
34
|
class MlFramework(str, Enum):
|
|
35
35
|
"""Available frameworks."""
|
|
36
36
|
|
|
37
|
-
NUMPY = "NumPy"
|
|
38
37
|
PYTORCH = "PyTorch"
|
|
39
38
|
TENSORFLOW = "TensorFlow"
|
|
40
|
-
|
|
39
|
+
SKLEARN = "sklearn"
|
|
41
40
|
HUGGINGFACE = "HuggingFace"
|
|
41
|
+
JAX = "JAX"
|
|
42
42
|
MLX = "MLX"
|
|
43
|
-
|
|
43
|
+
NUMPY = "NumPy"
|
|
44
44
|
FLOWERTUNE = "FlowerTune"
|
|
45
45
|
|
|
46
46
|
|
|
@@ -92,9 +92,9 @@ def render_and_create(file_path: Path, template: str, context: Dict[str, str]) -
|
|
|
92
92
|
|
|
93
93
|
# pylint: disable=too-many-locals,too-many-branches,too-many-statements
|
|
94
94
|
def new(
|
|
95
|
-
|
|
95
|
+
app_name: Annotated[
|
|
96
96
|
Optional[str],
|
|
97
|
-
typer.Argument(
|
|
97
|
+
typer.Argument(help="The name of the Flower App"),
|
|
98
98
|
] = None,
|
|
99
99
|
framework: Annotated[
|
|
100
100
|
Optional[MlFramework],
|
|
@@ -105,26 +105,26 @@ def new(
|
|
|
105
105
|
typer.Option(case_sensitive=False, help="The Flower username of the author"),
|
|
106
106
|
] = None,
|
|
107
107
|
) -> None:
|
|
108
|
-
"""Create new Flower
|
|
109
|
-
if
|
|
110
|
-
|
|
111
|
-
if not is_valid_project_name(
|
|
112
|
-
|
|
108
|
+
"""Create new Flower App."""
|
|
109
|
+
if app_name is None:
|
|
110
|
+
app_name = prompt_text("Please provide the app name")
|
|
111
|
+
if not is_valid_project_name(app_name):
|
|
112
|
+
app_name = prompt_text(
|
|
113
113
|
"Please provide a name that only contains "
|
|
114
114
|
"characters in {'-', a-zA-Z', '0-9'}",
|
|
115
115
|
predicate=is_valid_project_name,
|
|
116
|
-
default=sanitize_project_name(
|
|
116
|
+
default=sanitize_project_name(app_name),
|
|
117
117
|
)
|
|
118
118
|
|
|
119
119
|
# Set project directory path
|
|
120
|
-
package_name = re.sub(r"[-_.]+", "-",
|
|
120
|
+
package_name = re.sub(r"[-_.]+", "-", app_name).lower()
|
|
121
121
|
import_name = package_name.replace("-", "_")
|
|
122
122
|
project_dir = Path.cwd() / package_name
|
|
123
123
|
|
|
124
124
|
if project_dir.exists():
|
|
125
125
|
if not typer.confirm(
|
|
126
126
|
typer.style(
|
|
127
|
-
f"\n💬 {
|
|
127
|
+
f"\n💬 {app_name} already exists, do you want to override it?",
|
|
128
128
|
fg=typer.colors.MAGENTA,
|
|
129
129
|
bold=True,
|
|
130
130
|
)
|
|
@@ -135,20 +135,20 @@ def new(
|
|
|
135
135
|
username = prompt_text("Please provide your Flower username")
|
|
136
136
|
|
|
137
137
|
if framework is not None:
|
|
138
|
-
|
|
138
|
+
framework_str_upper = str(framework.value)
|
|
139
139
|
else:
|
|
140
140
|
framework_value = prompt_options(
|
|
141
141
|
"Please select ML framework by typing in the number",
|
|
142
|
-
|
|
142
|
+
[mlf.value for mlf in MlFramework],
|
|
143
143
|
)
|
|
144
144
|
selected_value = [
|
|
145
145
|
name
|
|
146
146
|
for name, value in vars(MlFramework).items()
|
|
147
147
|
if value == framework_value
|
|
148
148
|
]
|
|
149
|
-
|
|
149
|
+
framework_str_upper = selected_value[0]
|
|
150
150
|
|
|
151
|
-
framework_str =
|
|
151
|
+
framework_str = framework_str_upper.lower()
|
|
152
152
|
|
|
153
153
|
llm_challenge_str = None
|
|
154
154
|
if framework_str == "flowertune":
|
|
@@ -166,16 +166,17 @@ def new(
|
|
|
166
166
|
|
|
167
167
|
print(
|
|
168
168
|
typer.style(
|
|
169
|
-
f"\n🔨 Creating Flower
|
|
169
|
+
f"\n🔨 Creating Flower App {app_name}...",
|
|
170
170
|
fg=typer.colors.GREEN,
|
|
171
171
|
bold=True,
|
|
172
172
|
)
|
|
173
173
|
)
|
|
174
174
|
|
|
175
175
|
context = {
|
|
176
|
-
"
|
|
177
|
-
"package_name": package_name,
|
|
176
|
+
"framework_str": framework_str_upper,
|
|
178
177
|
"import_name": import_name.replace("-", "_"),
|
|
178
|
+
"package_name": package_name,
|
|
179
|
+
"project_name": app_name,
|
|
179
180
|
"username": username,
|
|
180
181
|
}
|
|
181
182
|
|
|
@@ -267,8 +268,8 @@ def new(
|
|
|
267
268
|
|
|
268
269
|
print(
|
|
269
270
|
typer.style(
|
|
270
|
-
"🎊
|
|
271
|
-
"Use the following command to run your
|
|
271
|
+
"🎊 Flower App creation successful.\n\n"
|
|
272
|
+
"Use the following command to run your Flower App:\n",
|
|
272
273
|
fg=typer.colors.GREEN,
|
|
273
274
|
bold=True,
|
|
274
275
|
)
|
|
@@ -1 +1 @@
|
|
|
1
|
-
"""$project_name."""
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
|
@@ -1,11 +1,11 @@
|
|
|
1
|
-
"""$project_name: A Flower /
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
|
2
2
|
|
|
3
|
+
import torch
|
|
3
4
|
from flwr.client import NumPyClient, ClientApp
|
|
4
5
|
from flwr.common import Context
|
|
5
6
|
|
|
6
7
|
from $import_name.task import (
|
|
7
8
|
Net,
|
|
8
|
-
DEVICE,
|
|
9
9
|
load_data,
|
|
10
10
|
get_weights,
|
|
11
11
|
set_weights,
|
|
@@ -21,27 +21,28 @@ class FlowerClient(NumPyClient):
|
|
|
21
21
|
self.trainloader = trainloader
|
|
22
22
|
self.valloader = valloader
|
|
23
23
|
self.local_epochs = local_epochs
|
|
24
|
+
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
25
|
+
self.net.to(self.device)
|
|
24
26
|
|
|
25
27
|
def fit(self, parameters, config):
|
|
26
28
|
set_weights(self.net, parameters)
|
|
27
|
-
|
|
29
|
+
train_loss = train(
|
|
28
30
|
self.net,
|
|
29
31
|
self.trainloader,
|
|
30
|
-
self.valloader,
|
|
31
32
|
self.local_epochs,
|
|
32
|
-
|
|
33
|
+
self.device,
|
|
33
34
|
)
|
|
34
|
-
return get_weights(self.net), len(self.trainloader.dataset),
|
|
35
|
+
return get_weights(self.net), len(self.trainloader.dataset), {"train_loss": train_loss}
|
|
35
36
|
|
|
36
37
|
def evaluate(self, parameters, config):
|
|
37
38
|
set_weights(self.net, parameters)
|
|
38
|
-
loss, accuracy = test(self.net, self.valloader)
|
|
39
|
+
loss, accuracy = test(self.net, self.valloader, self.device)
|
|
39
40
|
return loss, len(self.valloader.dataset), {"accuracy": accuracy}
|
|
40
41
|
|
|
41
42
|
|
|
42
43
|
def client_fn(context: Context):
|
|
43
44
|
# Load model and data
|
|
44
|
-
net = Net()
|
|
45
|
+
net = Net()
|
|
45
46
|
partition_id = context.node_config["partition-id"]
|
|
46
47
|
num_partitions = context.node_config["num-partitions"]
|
|
47
48
|
trainloader, valloader = load_data(partition_id, num_partitions)
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
"""$project_name: A Flower /
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
|
2
2
|
|
|
3
3
|
from flwr.client import NumPyClient, ClientApp
|
|
4
4
|
from flwr.common import Context
|
|
@@ -9,13 +9,10 @@ from $import_name.task import load_data, load_model
|
|
|
9
9
|
# Define Flower Client and client_fn
|
|
10
10
|
class FlowerClient(NumPyClient):
|
|
11
11
|
def __init__(
|
|
12
|
-
self, model,
|
|
12
|
+
self, model, data, epochs, batch_size, verbose
|
|
13
13
|
):
|
|
14
14
|
self.model = model
|
|
15
|
-
self.x_train =
|
|
16
|
-
self.y_train = y_train
|
|
17
|
-
self.x_test = x_test
|
|
18
|
-
self.y_test = y_test
|
|
15
|
+
self.x_train, self.y_train, self.x_test, self.y_test = data
|
|
19
16
|
self.epochs = epochs
|
|
20
17
|
self.batch_size = batch_size
|
|
21
18
|
self.verbose = verbose
|
|
@@ -46,14 +43,14 @@ def client_fn(context: Context):
|
|
|
46
43
|
|
|
47
44
|
partition_id = context.node_config["partition-id"]
|
|
48
45
|
num_partitions = context.node_config["num-partitions"]
|
|
49
|
-
|
|
46
|
+
data = load_data(partition_id, num_partitions)
|
|
50
47
|
epochs = context.run_config["local-epochs"]
|
|
51
48
|
batch_size = context.run_config["batch-size"]
|
|
52
49
|
verbose = context.run_config.get("verbose")
|
|
53
50
|
|
|
54
51
|
# Return Client instance
|
|
55
52
|
return FlowerClient(
|
|
56
|
-
net,
|
|
53
|
+
net, data, epochs, batch_size, verbose
|
|
57
54
|
).to_client()
|
|
58
55
|
|
|
59
56
|
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
"""$project_name: A Flower /
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
|
2
2
|
|
|
3
3
|
from flwr.common import Context, ndarrays_to_parameters
|
|
4
4
|
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
|
|
@@ -7,17 +7,18 @@ from flwr.server.strategy import FedAvg
|
|
|
7
7
|
from $import_name.task import Net, get_weights
|
|
8
8
|
|
|
9
9
|
|
|
10
|
-
# Initialize model parameters
|
|
11
|
-
ndarrays = get_weights(Net())
|
|
12
|
-
parameters = ndarrays_to_parameters(ndarrays)
|
|
13
|
-
|
|
14
10
|
def server_fn(context: Context):
|
|
15
11
|
# Read from config
|
|
16
12
|
num_rounds = context.run_config["num-server-rounds"]
|
|
13
|
+
fraction_fit = context.run_config["fraction-fit"]
|
|
14
|
+
|
|
15
|
+
# Initialize model parameters
|
|
16
|
+
ndarrays = get_weights(Net())
|
|
17
|
+
parameters = ndarrays_to_parameters(ndarrays)
|
|
17
18
|
|
|
18
19
|
# Define strategy
|
|
19
20
|
strategy = FedAvg(
|
|
20
|
-
fraction_fit=
|
|
21
|
+
fraction_fit=fraction_fit,
|
|
21
22
|
fraction_evaluate=1.0,
|
|
22
23
|
min_available_clients=2,
|
|
23
24
|
initial_parameters=parameters,
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
"""$project_name: A Flower /
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
|
2
2
|
|
|
3
3
|
from flwr.common import Context, ndarrays_to_parameters
|
|
4
4
|
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
|
|
@@ -6,15 +6,14 @@ from flwr.server.strategy import FedAvg
|
|
|
6
6
|
|
|
7
7
|
from $import_name.task import load_model
|
|
8
8
|
|
|
9
|
-
# Define config
|
|
10
|
-
config = ServerConfig(num_rounds=3)
|
|
11
|
-
|
|
12
|
-
parameters = ndarrays_to_parameters(load_model().get_weights())
|
|
13
9
|
|
|
14
10
|
def server_fn(context: Context):
|
|
15
11
|
# Read from config
|
|
16
12
|
num_rounds = context.run_config["num-server-rounds"]
|
|
17
13
|
|
|
14
|
+
# Get parameters to initialize global model
|
|
15
|
+
parameters = ndarrays_to_parameters(load_model().get_weights())
|
|
16
|
+
|
|
18
17
|
# Define strategy
|
|
19
18
|
strategy = strategy = FedAvg(
|
|
20
19
|
fraction_fit=1.0,
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
"""$project_name: A Flower /
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
|
2
2
|
|
|
3
3
|
import jax
|
|
4
4
|
import jax.numpy as jnp
|
|
@@ -33,7 +33,7 @@ def train(params, grad_fn, X, y):
|
|
|
33
33
|
num_examples = X.shape[0]
|
|
34
34
|
for epochs in range(50):
|
|
35
35
|
grads = grad_fn(params, X, y)
|
|
36
|
-
params = jax.
|
|
36
|
+
params = jax.tree_map(lambda p, g: p - 0.05 * g, params, grads)
|
|
37
37
|
loss = loss_fn(params, X, y)
|
|
38
38
|
return params, loss, num_examples
|
|
39
39
|
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
"""$project_name: A Flower /
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
|
2
2
|
|
|
3
3
|
import mlx.core as mx
|
|
4
4
|
import mlx.nn as nn
|
|
@@ -56,6 +56,7 @@ def load_data(partition_id: int, num_partitions: int):
|
|
|
56
56
|
fds = FederatedDataset(
|
|
57
57
|
dataset="ylecun/mnist",
|
|
58
58
|
partitioners={"train": partitioner},
|
|
59
|
+
trust_remote_code=True,
|
|
59
60
|
)
|
|
60
61
|
partition = fds.load_partition(partition_id)
|
|
61
62
|
partition_splits = partition.train_test_split(test_size=0.2, seed=42)
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
"""$project_name: A Flower /
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
|
2
2
|
|
|
3
3
|
from collections import OrderedDict
|
|
4
4
|
|
|
@@ -11,9 +11,6 @@ from flwr_datasets import FederatedDataset
|
|
|
11
11
|
from flwr_datasets.partitioner import IidPartitioner
|
|
12
12
|
|
|
13
13
|
|
|
14
|
-
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
15
|
-
|
|
16
|
-
|
|
17
14
|
class Net(nn.Module):
|
|
18
15
|
"""Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')"""
|
|
19
16
|
|
|
@@ -66,44 +63,41 @@ def load_data(partition_id: int, num_partitions: int):
|
|
|
66
63
|
return trainloader, testloader
|
|
67
64
|
|
|
68
65
|
|
|
69
|
-
def train(net, trainloader,
|
|
66
|
+
def train(net, trainloader, epochs, device):
|
|
70
67
|
"""Train the model on the training set."""
|
|
71
68
|
net.to(device) # move model to GPU if available
|
|
72
69
|
criterion = torch.nn.CrossEntropyLoss().to(device)
|
|
73
|
-
optimizer = torch.optim.SGD(net.parameters(), lr=0.
|
|
70
|
+
optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
|
|
74
71
|
net.train()
|
|
72
|
+
running_loss = 0.0
|
|
75
73
|
for _ in range(epochs):
|
|
76
74
|
for batch in trainloader:
|
|
77
75
|
images = batch["img"]
|
|
78
76
|
labels = batch["label"]
|
|
79
77
|
optimizer.zero_grad()
|
|
80
|
-
criterion(net(images.to(
|
|
78
|
+
loss = criterion(net(images.to(device)), labels.to(device))
|
|
79
|
+
loss.backward()
|
|
81
80
|
optimizer.step()
|
|
81
|
+
running_loss += loss.item()
|
|
82
82
|
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
results = {
|
|
87
|
-
"train_loss": train_loss,
|
|
88
|
-
"train_accuracy": train_acc,
|
|
89
|
-
"val_loss": val_loss,
|
|
90
|
-
"val_accuracy": val_acc,
|
|
91
|
-
}
|
|
92
|
-
return results
|
|
83
|
+
avg_trainloss = running_loss / len(trainloader)
|
|
84
|
+
return avg_trainloss
|
|
93
85
|
|
|
94
86
|
|
|
95
|
-
def test(net, testloader):
|
|
87
|
+
def test(net, testloader, device):
|
|
96
88
|
"""Validate the model on the test set."""
|
|
89
|
+
net.to(device)
|
|
97
90
|
criterion = torch.nn.CrossEntropyLoss()
|
|
98
91
|
correct, loss = 0, 0.0
|
|
99
92
|
with torch.no_grad():
|
|
100
93
|
for batch in testloader:
|
|
101
|
-
images = batch["img"].to(
|
|
102
|
-
labels = batch["label"].to(
|
|
94
|
+
images = batch["img"].to(device)
|
|
95
|
+
labels = batch["label"].to(device)
|
|
103
96
|
outputs = net(images)
|
|
104
97
|
loss += criterion(outputs, labels).item()
|
|
105
98
|
correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
|
|
106
99
|
accuracy = correct / len(testloader.dataset)
|
|
100
|
+
loss = loss / len(testloader)
|
|
107
101
|
return loss, accuracy
|
|
108
102
|
|
|
109
103
|
|
|
@@ -1,8 +1,9 @@
|
|
|
1
|
-
"""$project_name: A Flower /
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
|
2
2
|
|
|
3
3
|
import os
|
|
4
4
|
|
|
5
|
-
import
|
|
5
|
+
import keras
|
|
6
|
+
from keras import layers
|
|
6
7
|
from flwr_datasets import FederatedDataset
|
|
7
8
|
from flwr_datasets.partitioner import IidPartitioner
|
|
8
9
|
|
|
@@ -12,8 +13,19 @@ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
|
|
12
13
|
|
|
13
14
|
|
|
14
15
|
def load_model():
|
|
15
|
-
#
|
|
16
|
-
model =
|
|
16
|
+
# Define a simple CNN for CIFAR-10 and set Adam optimizer
|
|
17
|
+
model = keras.Sequential(
|
|
18
|
+
[
|
|
19
|
+
keras.Input(shape=(32, 32, 3)),
|
|
20
|
+
layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
|
|
21
|
+
layers.MaxPooling2D(pool_size=(2, 2)),
|
|
22
|
+
layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
|
|
23
|
+
layers.MaxPooling2D(pool_size=(2, 2)),
|
|
24
|
+
layers.Flatten(),
|
|
25
|
+
layers.Dropout(0.5),
|
|
26
|
+
layers.Dense(10, activation="softmax"),
|
|
27
|
+
]
|
|
28
|
+
)
|
|
17
29
|
model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"])
|
|
18
30
|
return model
|
|
19
31
|
|
|
@@ -8,8 +8,8 @@ version = "1.0.0"
|
|
|
8
8
|
description = ""
|
|
9
9
|
license = "Apache-2.0"
|
|
10
10
|
dependencies = [
|
|
11
|
-
"flwr[simulation]>=1.
|
|
12
|
-
"flwr-datasets>=0.
|
|
11
|
+
"flwr[simulation]>=1.10.0",
|
|
12
|
+
"flwr-datasets>=0.3.0",
|
|
13
13
|
"torch==2.2.1",
|
|
14
14
|
"transformers>=4.30.0,<5.0",
|
|
15
15
|
"evaluate>=0.4.0,<1.0",
|