flwr-nightly 1.12.0.dev20241007__py3-none-any.whl → 1.12.0.dev20241010__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/build.py +60 -29
- flwr/cli/config_utils.py +10 -0
- flwr/cli/install.py +60 -20
- flwr/cli/new/new.py +2 -0
- flwr/cli/new/templates/app/code/client.jax.py.tpl +11 -17
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +16 -36
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +4 -5
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +8 -11
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +14 -48
- flwr/cli/new/templates/app/code/server.jax.py.tpl +9 -3
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +13 -2
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +7 -2
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +13 -1
- flwr/cli/new/templates/app/code/task.jax.py.tpl +2 -2
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +1 -1
- flwr/cli/new/templates/app/code/task.numpy.py.tpl +7 -0
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +3 -3
- flwr/cli/new/templates/app/code/task.sklearn.py.tpl +67 -0
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +3 -2
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -0
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +2 -0
- flwr/cli/run/run.py +5 -5
- flwr/client/app.py +13 -3
- flwr/client/clientapp/app.py +5 -2
- flwr/client/clientapp/utils.py +11 -5
- flwr/client/grpc_rere_client/connection.py +3 -0
- flwr/common/config.py +18 -5
- flwr/common/constant.py +3 -0
- flwr/common/message.py +5 -0
- flwr/common/recordset_compat.py +10 -0
- flwr/common/retry_invoker.py +15 -0
- flwr/server/client_manager.py +2 -0
- flwr/server/compat/driver_client_proxy.py +15 -29
- flwr/server/driver/inmemory_driver.py +6 -2
- flwr/server/run_serverapp.py +11 -13
- flwr/server/superlink/driver/driver_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +3 -1
- flwr/server/superlink/fleet/vce/vce_api.py +1 -1
- flwr/server/superlink/state/in_memory_state.py +26 -8
- flwr/server/superlink/state/sqlite_state.py +46 -11
- flwr/server/superlink/state/state.py +1 -7
- flwr/server/superlink/state/utils.py +0 -10
- {flwr_nightly-1.12.0.dev20241007.dist-info → flwr_nightly-1.12.0.dev20241010.dist-info}/METADATA +1 -1
- {flwr_nightly-1.12.0.dev20241007.dist-info → flwr_nightly-1.12.0.dev20241010.dist-info}/RECORD +49 -47
- {flwr_nightly-1.12.0.dev20241007.dist-info → flwr_nightly-1.12.0.dev20241010.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.12.0.dev20241007.dist-info → flwr_nightly-1.12.0.dev20241010.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.12.0.dev20241007.dist-info → flwr_nightly-1.12.0.dev20241010.dist-info}/entry_points.txt +0 -0
flwr/cli/build.py
CHANGED
|
@@ -14,26 +14,50 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Flower command line interface `build` command."""
|
|
16
16
|
|
|
17
|
+
import hashlib
|
|
17
18
|
import os
|
|
19
|
+
import shutil
|
|
20
|
+
import tempfile
|
|
18
21
|
import zipfile
|
|
19
22
|
from pathlib import Path
|
|
20
|
-
from typing import Annotated, Optional
|
|
23
|
+
from typing import Annotated, Any, Optional, Union
|
|
21
24
|
|
|
22
25
|
import pathspec
|
|
23
26
|
import tomli_w
|
|
24
27
|
import typer
|
|
25
28
|
|
|
29
|
+
from flwr.common.constant import FAB_ALLOWED_EXTENSIONS, FAB_DATE, FAB_HASH_TRUNCATION
|
|
30
|
+
|
|
26
31
|
from .config_utils import load_and_validate
|
|
27
|
-
from .utils import
|
|
32
|
+
from .utils import is_valid_project_name
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def write_to_zip(
|
|
36
|
+
zipfile_obj: zipfile.ZipFile, filename: str, contents: Union[bytes, str]
|
|
37
|
+
) -> zipfile.ZipFile:
|
|
38
|
+
"""Set a fixed date and write contents to a zip file."""
|
|
39
|
+
zip_info = zipfile.ZipInfo(filename)
|
|
40
|
+
zip_info.date_time = FAB_DATE
|
|
41
|
+
zipfile_obj.writestr(zip_info, contents)
|
|
42
|
+
return zipfile_obj
|
|
43
|
+
|
|
28
44
|
|
|
45
|
+
def get_fab_filename(conf: dict[str, Any], fab_hash: str) -> str:
|
|
46
|
+
"""Get the FAB filename based on the given config and FAB hash."""
|
|
47
|
+
publisher = conf["tool"]["flwr"]["app"]["publisher"]
|
|
48
|
+
name = conf["project"]["name"]
|
|
49
|
+
version = conf["project"]["version"].replace(".", "-")
|
|
50
|
+
fab_hash_truncated = fab_hash[:FAB_HASH_TRUNCATION]
|
|
51
|
+
return f"{publisher}.{name}.{version}.{fab_hash_truncated}.fab"
|
|
29
52
|
|
|
30
|
-
|
|
53
|
+
|
|
54
|
+
# pylint: disable=too-many-locals, too-many-statements
|
|
31
55
|
def build(
|
|
32
56
|
app: Annotated[
|
|
33
57
|
Optional[Path],
|
|
34
58
|
typer.Option(help="Path of the Flower App to bundle into a FAB"),
|
|
35
59
|
] = None,
|
|
36
|
-
) -> str:
|
|
60
|
+
) -> tuple[str, str]:
|
|
37
61
|
"""Build a Flower App into a Flower App Bundle (FAB).
|
|
38
62
|
|
|
39
63
|
You can run ``flwr build`` without any arguments to bundle the app located in the
|
|
@@ -85,16 +109,8 @@ def build(
|
|
|
85
109
|
# Load .gitignore rules if present
|
|
86
110
|
ignore_spec = _load_gitignore(app)
|
|
87
111
|
|
|
88
|
-
# Set the name of the zip file
|
|
89
|
-
fab_filename = (
|
|
90
|
-
f"{conf['tool']['flwr']['app']['publisher']}"
|
|
91
|
-
f".{conf['project']['name']}"
|
|
92
|
-
f".{conf['project']['version'].replace('.', '-')}.fab"
|
|
93
|
-
)
|
|
94
112
|
list_file_content = ""
|
|
95
113
|
|
|
96
|
-
allowed_extensions = {".py", ".toml", ".md"}
|
|
97
|
-
|
|
98
114
|
# Remove the 'federations' field from 'tool.flwr' if it exists
|
|
99
115
|
if (
|
|
100
116
|
"tool" in conf
|
|
@@ -105,38 +121,53 @@ def build(
|
|
|
105
121
|
|
|
106
122
|
toml_contents = tomli_w.dumps(conf)
|
|
107
123
|
|
|
108
|
-
with
|
|
109
|
-
|
|
124
|
+
with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as temp_file:
|
|
125
|
+
temp_filename = temp_file.name
|
|
126
|
+
|
|
127
|
+
with zipfile.ZipFile(temp_filename, "w", zipfile.ZIP_DEFLATED) as fab_file:
|
|
128
|
+
write_to_zip(fab_file, "pyproject.toml", toml_contents)
|
|
110
129
|
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
files = [
|
|
130
|
+
# Continue with adding other files
|
|
131
|
+
all_files = [
|
|
114
132
|
f
|
|
115
|
-
for f in
|
|
116
|
-
if not ignore_spec.match_file(
|
|
117
|
-
and f !=
|
|
118
|
-
and
|
|
119
|
-
and f != "pyproject.toml" # Exclude the original pyproject.toml
|
|
133
|
+
for f in app.rglob("*")
|
|
134
|
+
if not ignore_spec.match_file(f)
|
|
135
|
+
and f.name != temp_filename
|
|
136
|
+
and f.suffix in FAB_ALLOWED_EXTENSIONS
|
|
137
|
+
and f.name != "pyproject.toml" # Exclude the original pyproject.toml
|
|
120
138
|
]
|
|
121
139
|
|
|
122
|
-
for
|
|
123
|
-
|
|
140
|
+
for file_path in all_files:
|
|
141
|
+
# Read the file content manually
|
|
142
|
+
with open(file_path, "rb") as f:
|
|
143
|
+
file_contents = f.read()
|
|
144
|
+
|
|
124
145
|
archive_path = file_path.relative_to(app)
|
|
125
|
-
fab_file
|
|
146
|
+
write_to_zip(fab_file, str(archive_path), file_contents)
|
|
126
147
|
|
|
127
148
|
# Calculate file info
|
|
128
|
-
sha256_hash =
|
|
149
|
+
sha256_hash = hashlib.sha256(file_contents).hexdigest()
|
|
129
150
|
file_size_bits = os.path.getsize(file_path) * 8 # size in bits
|
|
130
151
|
list_file_content += f"{archive_path},{sha256_hash},{file_size_bits}\n"
|
|
131
152
|
|
|
132
|
-
|
|
133
|
-
|
|
153
|
+
# Add CONTENT and CONTENT.jwt to the zip file
|
|
154
|
+
write_to_zip(fab_file, ".info/CONTENT", list_file_content)
|
|
155
|
+
|
|
156
|
+
# Get hash of FAB file
|
|
157
|
+
content = Path(temp_filename).read_bytes()
|
|
158
|
+
fab_hash = hashlib.sha256(content).hexdigest()
|
|
159
|
+
|
|
160
|
+
# Set the name of the zip file
|
|
161
|
+
fab_filename = get_fab_filename(conf, fab_hash)
|
|
162
|
+
|
|
163
|
+
# Once the temporary zip file is created, rename it to the final filename
|
|
164
|
+
shutil.move(temp_filename, fab_filename)
|
|
134
165
|
|
|
135
166
|
typer.secho(
|
|
136
167
|
f"🎊 Successfully built {fab_filename}", fg=typer.colors.GREEN, bold=True
|
|
137
168
|
)
|
|
138
169
|
|
|
139
|
-
return fab_filename
|
|
170
|
+
return fab_filename, fab_hash
|
|
140
171
|
|
|
141
172
|
|
|
142
173
|
def _load_gitignore(app: Path) -> pathspec.PathSpec:
|
flwr/cli/config_utils.py
CHANGED
|
@@ -90,6 +90,16 @@ def load_and_validate(
|
|
|
90
90
|
) -> tuple[Optional[dict[str, Any]], list[str], list[str]]:
|
|
91
91
|
"""Load and validate pyproject.toml as dict.
|
|
92
92
|
|
|
93
|
+
Parameters
|
|
94
|
+
----------
|
|
95
|
+
path : Optional[Path] (default: None)
|
|
96
|
+
The path of the Flower App config file to load. By default it
|
|
97
|
+
will try to use `pyproject.toml` inside the current directory.
|
|
98
|
+
check_module: bool (default: True)
|
|
99
|
+
Whether the validity of the Python module should be checked.
|
|
100
|
+
This requires the project to be installed in the currently
|
|
101
|
+
running environment. True by default.
|
|
102
|
+
|
|
93
103
|
Returns
|
|
94
104
|
-------
|
|
95
105
|
Tuple[Optional[config], List[str], List[str]]
|
flwr/cli/install.py
CHANGED
|
@@ -14,7 +14,7 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Flower command line interface `install` command."""
|
|
16
16
|
|
|
17
|
-
|
|
17
|
+
import hashlib
|
|
18
18
|
import shutil
|
|
19
19
|
import subprocess
|
|
20
20
|
import tempfile
|
|
@@ -25,7 +25,8 @@ from typing import IO, Annotated, Optional, Union
|
|
|
25
25
|
|
|
26
26
|
import typer
|
|
27
27
|
|
|
28
|
-
from flwr.common.config import get_flwr_dir
|
|
28
|
+
from flwr.common.config import get_flwr_dir, get_metadata_from_config
|
|
29
|
+
from flwr.common.constant import FAB_HASH_TRUNCATION
|
|
29
30
|
|
|
30
31
|
from .config_utils import load_and_validate
|
|
31
32
|
from .utils import get_sha256_hash
|
|
@@ -91,9 +92,11 @@ def install_from_fab(
|
|
|
91
92
|
fab_name: Optional[str]
|
|
92
93
|
if isinstance(fab_file, bytes):
|
|
93
94
|
fab_file_archive = BytesIO(fab_file)
|
|
95
|
+
fab_hash = hashlib.sha256(fab_file).hexdigest()
|
|
94
96
|
fab_name = None
|
|
95
97
|
elif isinstance(fab_file, Path):
|
|
96
98
|
fab_file_archive = fab_file
|
|
99
|
+
fab_hash = hashlib.sha256(fab_file.read_bytes()).hexdigest()
|
|
97
100
|
fab_name = fab_file.stem
|
|
98
101
|
else:
|
|
99
102
|
raise ValueError("fab_file must be either a Path or bytes")
|
|
@@ -126,14 +129,16 @@ def install_from_fab(
|
|
|
126
129
|
shutil.rmtree(info_dir)
|
|
127
130
|
|
|
128
131
|
installed_path = validate_and_install(
|
|
129
|
-
tmpdir_path, fab_name, flwr_dir, skip_prompt
|
|
132
|
+
tmpdir_path, fab_hash, fab_name, flwr_dir, skip_prompt
|
|
130
133
|
)
|
|
131
134
|
|
|
132
135
|
return installed_path
|
|
133
136
|
|
|
134
137
|
|
|
138
|
+
# pylint: disable=too-many-locals
|
|
135
139
|
def validate_and_install(
|
|
136
140
|
project_dir: Path,
|
|
141
|
+
fab_hash: str,
|
|
137
142
|
fab_name: Optional[str],
|
|
138
143
|
flwr_dir: Optional[Path],
|
|
139
144
|
skip_prompt: bool = False,
|
|
@@ -149,28 +154,17 @@ def validate_and_install(
|
|
|
149
154
|
)
|
|
150
155
|
raise typer.Exit(code=1)
|
|
151
156
|
|
|
152
|
-
|
|
153
|
-
project_name =
|
|
154
|
-
|
|
157
|
+
version, fab_id = get_metadata_from_config(config)
|
|
158
|
+
publisher, project_name = fab_id.split("/")
|
|
159
|
+
config_metadata = (publisher, project_name, version, fab_hash)
|
|
155
160
|
|
|
156
|
-
if
|
|
157
|
-
fab_name
|
|
158
|
-
and fab_name != f"{publisher}.{project_name}.{version.replace('.', '-')}"
|
|
159
|
-
):
|
|
160
|
-
typer.secho(
|
|
161
|
-
"❌ FAB file has incorrect name. The file name must follow the format "
|
|
162
|
-
"`<publisher>.<project_name>.<version>.fab`.",
|
|
163
|
-
fg=typer.colors.RED,
|
|
164
|
-
bold=True,
|
|
165
|
-
)
|
|
166
|
-
raise typer.Exit(code=1)
|
|
161
|
+
if fab_name:
|
|
162
|
+
_validate_fab_and_config_metadata(fab_name, config_metadata)
|
|
167
163
|
|
|
168
164
|
install_dir: Path = (
|
|
169
165
|
(get_flwr_dir() if not flwr_dir else flwr_dir)
|
|
170
166
|
/ "apps"
|
|
171
|
-
/ publisher
|
|
172
|
-
/ project_name
|
|
173
|
-
/ version
|
|
167
|
+
/ f"{publisher}.{project_name}.{version}.{fab_hash[:FAB_HASH_TRUNCATION]}"
|
|
174
168
|
)
|
|
175
169
|
if install_dir.exists():
|
|
176
170
|
if skip_prompt:
|
|
@@ -226,3 +220,49 @@ def _verify_hashes(list_content: str, tmpdir: Path) -> bool:
|
|
|
226
220
|
if not file_path.exists() or get_sha256_hash(file_path) != hash_expected:
|
|
227
221
|
return False
|
|
228
222
|
return True
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def _validate_fab_and_config_metadata(
|
|
226
|
+
fab_name: str, config_metadata: tuple[str, str, str, str]
|
|
227
|
+
) -> None:
|
|
228
|
+
"""Validate metadata from the FAB filename and config."""
|
|
229
|
+
publisher, project_name, version, fab_hash = config_metadata
|
|
230
|
+
|
|
231
|
+
fab_name = fab_name.removesuffix(".fab")
|
|
232
|
+
|
|
233
|
+
fab_publisher, fab_project_name, fab_version, fab_shorthash = fab_name.split(".")
|
|
234
|
+
fab_version = fab_version.replace("-", ".")
|
|
235
|
+
|
|
236
|
+
# Check FAB filename format
|
|
237
|
+
if (
|
|
238
|
+
f"{fab_publisher}.{fab_project_name}.{fab_version}"
|
|
239
|
+
!= f"{publisher}.{project_name}.{version}"
|
|
240
|
+
or len(fab_shorthash) != FAB_HASH_TRUNCATION # Verify hash length
|
|
241
|
+
):
|
|
242
|
+
typer.secho(
|
|
243
|
+
"❌ FAB file has incorrect name. The file name must follow the format "
|
|
244
|
+
"`<publisher>.<project_name>.<version>.<8hexchars>.fab`.",
|
|
245
|
+
fg=typer.colors.RED,
|
|
246
|
+
bold=True,
|
|
247
|
+
)
|
|
248
|
+
raise typer.Exit(code=1)
|
|
249
|
+
|
|
250
|
+
# Verify hash is a valid hexadecimal
|
|
251
|
+
try:
|
|
252
|
+
_ = int(fab_shorthash, 16)
|
|
253
|
+
except Exception as e:
|
|
254
|
+
typer.secho(
|
|
255
|
+
f"❌ FAB file has an invalid hexadecimal string `{fab_shorthash}`.",
|
|
256
|
+
fg=typer.colors.RED,
|
|
257
|
+
bold=True,
|
|
258
|
+
)
|
|
259
|
+
raise typer.Exit(code=1) from e
|
|
260
|
+
|
|
261
|
+
# Verify shorthash matches
|
|
262
|
+
if fab_shorthash != fab_hash[:FAB_HASH_TRUNCATION]:
|
|
263
|
+
typer.secho(
|
|
264
|
+
"❌ The hash in the FAB file name does not match the hash of the FAB.",
|
|
265
|
+
fg=typer.colors.RED,
|
|
266
|
+
bold=True,
|
|
267
|
+
)
|
|
268
|
+
raise typer.Exit(code=1)
|
flwr/cli/new/new.py
CHANGED
|
@@ -240,6 +240,8 @@ def new(
|
|
|
240
240
|
MlFramework.HUGGINGFACE.value,
|
|
241
241
|
MlFramework.MLX.value,
|
|
242
242
|
MlFramework.TENSORFLOW.value,
|
|
243
|
+
MlFramework.SKLEARN.value,
|
|
244
|
+
MlFramework.NUMPY.value,
|
|
243
245
|
]
|
|
244
246
|
if framework_str in frameworks_with_tasks:
|
|
245
247
|
files[f"{import_name}/task.py"] = {
|
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
"""$project_name: A Flower / $framework_str app."""
|
|
2
2
|
|
|
3
3
|
import jax
|
|
4
|
-
from flwr.client import NumPyClient, ClientApp
|
|
5
|
-
from flwr.common import Context
|
|
6
4
|
|
|
5
|
+
from flwr.client import ClientApp, NumPyClient
|
|
6
|
+
from flwr.common import Context
|
|
7
7
|
from $import_name.task import (
|
|
8
8
|
evaluation,
|
|
9
9
|
get_params,
|
|
@@ -17,37 +17,31 @@ from $import_name.task import (
|
|
|
17
17
|
|
|
18
18
|
# Define Flower Client and client_fn
|
|
19
19
|
class FlowerClient(NumPyClient):
|
|
20
|
-
def __init__(self):
|
|
20
|
+
def __init__(self, input_dim):
|
|
21
21
|
self.train_x, self.train_y, self.test_x, self.test_y = load_data()
|
|
22
22
|
self.grad_fn = jax.grad(loss_fn)
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
self.params = load_model(model_shape)
|
|
26
|
-
|
|
27
|
-
def get_parameters(self, config):
|
|
28
|
-
return get_params(self.params)
|
|
29
|
-
|
|
30
|
-
def set_parameters(self, parameters):
|
|
31
|
-
set_params(self.params, parameters)
|
|
23
|
+
self.params = load_model((input_dim,))
|
|
32
24
|
|
|
33
25
|
def fit(self, parameters, config):
|
|
34
|
-
self.
|
|
26
|
+
set_params(self.params, parameters)
|
|
35
27
|
self.params, loss, num_examples = train(
|
|
36
28
|
self.params, self.grad_fn, self.train_x, self.train_y
|
|
37
29
|
)
|
|
38
|
-
|
|
39
|
-
return parameters, num_examples, {"loss": float(loss)}
|
|
30
|
+
return get_params(self.params), num_examples, {"loss": float(loss)}
|
|
40
31
|
|
|
41
32
|
def evaluate(self, parameters, config):
|
|
42
|
-
self.
|
|
33
|
+
set_params(self.params, parameters)
|
|
43
34
|
loss, num_examples = evaluation(
|
|
44
35
|
self.params, self.grad_fn, self.test_x, self.test_y
|
|
45
36
|
)
|
|
46
37
|
return float(loss), num_examples, {"loss": float(loss)}
|
|
47
38
|
|
|
39
|
+
|
|
48
40
|
def client_fn(context: Context):
|
|
41
|
+
input_dim = context.run_config["input-dim"]
|
|
42
|
+
|
|
49
43
|
# Return Client instance
|
|
50
|
-
return FlowerClient().to_client()
|
|
44
|
+
return FlowerClient(input_dim).to_client()
|
|
51
45
|
|
|
52
46
|
|
|
53
47
|
# Flower ClientApp
|
|
@@ -3,17 +3,18 @@
|
|
|
3
3
|
import mlx.core as mx
|
|
4
4
|
import mlx.nn as nn
|
|
5
5
|
import mlx.optimizers as optim
|
|
6
|
-
from flwr.client import NumPyClient, ClientApp
|
|
7
|
-
from flwr.common import Context
|
|
8
6
|
|
|
7
|
+
from flwr.client import ClientApp, NumPyClient
|
|
8
|
+
from flwr.common import Context
|
|
9
|
+
from flwr.common.config import UserConfig
|
|
9
10
|
from $import_name.task import (
|
|
11
|
+
MLP,
|
|
10
12
|
batch_iterate,
|
|
11
13
|
eval_fn,
|
|
12
14
|
get_params,
|
|
13
15
|
load_data,
|
|
14
16
|
loss_fn,
|
|
15
17
|
set_params,
|
|
16
|
-
MLP,
|
|
17
18
|
)
|
|
18
19
|
|
|
19
20
|
|
|
@@ -22,37 +23,24 @@ class FlowerClient(NumPyClient):
|
|
|
22
23
|
def __init__(
|
|
23
24
|
self,
|
|
24
25
|
data,
|
|
25
|
-
|
|
26
|
-
hidden_dim,
|
|
26
|
+
run_config: UserConfig,
|
|
27
27
|
num_classes,
|
|
28
|
-
batch_size,
|
|
29
|
-
learning_rate,
|
|
30
|
-
num_epochs,
|
|
31
28
|
):
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
self.num_epochs =
|
|
29
|
+
num_layers = run_config["num-layers"]
|
|
30
|
+
hidden_dim = run_config["hidden-dim"]
|
|
31
|
+
input_dim = run_config["input-dim"]
|
|
32
|
+
batch_size = run_config["batch-size"]
|
|
33
|
+
learning_rate = run_config["lr"]
|
|
34
|
+
self.num_epochs = run_config["local-epochs"]
|
|
38
35
|
|
|
39
36
|
self.train_images, self.train_labels, self.test_images, self.test_labels = data
|
|
40
|
-
self.model = MLP(
|
|
41
|
-
num_layers, self.train_images.shape[-1], hidden_dim, num_classes
|
|
42
|
-
)
|
|
37
|
+
self.model = MLP(num_layers, input_dim, hidden_dim, num_classes)
|
|
43
38
|
self.optimizer = optim.SGD(learning_rate=learning_rate)
|
|
44
39
|
self.loss_and_grad_fn = nn.value_and_grad(self.model, loss_fn)
|
|
45
|
-
self.num_epochs = num_epochs
|
|
46
40
|
self.batch_size = batch_size
|
|
47
41
|
|
|
48
|
-
def get_parameters(self, config):
|
|
49
|
-
return get_params(self.model)
|
|
50
|
-
|
|
51
|
-
def set_parameters(self, parameters):
|
|
52
|
-
set_params(self.model, parameters)
|
|
53
|
-
|
|
54
42
|
def fit(self, parameters, config):
|
|
55
|
-
self.
|
|
43
|
+
set_params(self.model, parameters)
|
|
56
44
|
for _ in range(self.num_epochs):
|
|
57
45
|
for X, y in batch_iterate(
|
|
58
46
|
self.batch_size, self.train_images, self.train_labels
|
|
@@ -60,10 +48,10 @@ class FlowerClient(NumPyClient):
|
|
|
60
48
|
_, grads = self.loss_and_grad_fn(self.model, X, y)
|
|
61
49
|
self.optimizer.update(self.model, grads)
|
|
62
50
|
mx.eval(self.model.parameters(), self.optimizer.state)
|
|
63
|
-
return self.
|
|
51
|
+
return get_params(self.model), len(self.train_images), {}
|
|
64
52
|
|
|
65
53
|
def evaluate(self, parameters, config):
|
|
66
|
-
self.
|
|
54
|
+
set_params(self.model, parameters)
|
|
67
55
|
accuracy = eval_fn(self.model, self.test_images, self.test_labels)
|
|
68
56
|
loss = loss_fn(self.model, self.test_images, self.test_labels)
|
|
69
57
|
return loss.item(), len(self.test_images), {"accuracy": accuracy.item()}
|
|
@@ -73,18 +61,10 @@ def client_fn(context: Context):
|
|
|
73
61
|
partition_id = context.node_config["partition-id"]
|
|
74
62
|
num_partitions = context.node_config["num-partitions"]
|
|
75
63
|
data = load_data(partition_id, num_partitions)
|
|
76
|
-
|
|
77
|
-
num_layers = context.run_config["num-layers"]
|
|
78
|
-
hidden_dim = context.run_config["hidden-dim"]
|
|
79
64
|
num_classes = 10
|
|
80
|
-
batch_size = context.run_config["batch-size"]
|
|
81
|
-
learning_rate = context.run_config["lr"]
|
|
82
|
-
num_epochs = context.run_config["local-epochs"]
|
|
83
65
|
|
|
84
66
|
# Return Client instance
|
|
85
|
-
return FlowerClient(
|
|
86
|
-
data, num_layers, hidden_dim, num_classes, batch_size, learning_rate, num_epochs
|
|
87
|
-
).to_client()
|
|
67
|
+
return FlowerClient(data, context.run_config, num_classes).to_client()
|
|
88
68
|
|
|
89
69
|
|
|
90
70
|
# Flower ClientApp
|
|
@@ -1,16 +1,15 @@
|
|
|
1
1
|
"""$project_name: A Flower / $framework_str app."""
|
|
2
2
|
|
|
3
|
-
from flwr.client import
|
|
3
|
+
from flwr.client import ClientApp, NumPyClient
|
|
4
4
|
from flwr.common import Context
|
|
5
|
-
|
|
5
|
+
from $import_name.task import get_dummy_model
|
|
6
6
|
|
|
7
7
|
|
|
8
8
|
class FlowerClient(NumPyClient):
|
|
9
|
-
def get_parameters(self, config):
|
|
10
|
-
return [np.ones((1, 1))]
|
|
11
9
|
|
|
12
10
|
def fit(self, parameters, config):
|
|
13
|
-
|
|
11
|
+
model = get_dummy_model()
|
|
12
|
+
return [model], 1, {}
|
|
14
13
|
|
|
15
14
|
def evaluate(self, parameters, config):
|
|
16
15
|
return float(0.0), 1, {"accuracy": float(1.0)}
|
|
@@ -1,17 +1,10 @@
|
|
|
1
1
|
"""$project_name: A Flower / $framework_str app."""
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
|
-
from flwr.client import NumPyClient, ClientApp
|
|
5
|
-
from flwr.common import Context
|
|
6
4
|
|
|
7
|
-
from
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
get_weights,
|
|
11
|
-
set_weights,
|
|
12
|
-
train,
|
|
13
|
-
test,
|
|
14
|
-
)
|
|
5
|
+
from flwr.client import ClientApp, NumPyClient
|
|
6
|
+
from flwr.common import Context
|
|
7
|
+
from $import_name.task import Net, get_weights, load_data, set_weights, test, train
|
|
15
8
|
|
|
16
9
|
|
|
17
10
|
# Define Flower Client and client_fn
|
|
@@ -32,7 +25,11 @@ class FlowerClient(NumPyClient):
|
|
|
32
25
|
self.local_epochs,
|
|
33
26
|
self.device,
|
|
34
27
|
)
|
|
35
|
-
return
|
|
28
|
+
return (
|
|
29
|
+
get_weights(self.net),
|
|
30
|
+
len(self.trainloader.dataset),
|
|
31
|
+
{"train_loss": train_loss},
|
|
32
|
+
)
|
|
36
33
|
|
|
37
34
|
def evaluate(self, parameters, config):
|
|
38
35
|
set_weights(self.net, parameters)
|
|
@@ -2,40 +2,17 @@
|
|
|
2
2
|
|
|
3
3
|
import warnings
|
|
4
4
|
|
|
5
|
-
import numpy as np
|
|
6
|
-
from flwr.client import NumPyClient, ClientApp
|
|
7
|
-
from flwr.common import Context
|
|
8
|
-
from flwr_datasets import FederatedDataset
|
|
9
|
-
from sklearn.linear_model import LogisticRegression
|
|
10
5
|
from sklearn.metrics import log_loss
|
|
11
6
|
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
return params
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
def set_model_params(model, params):
|
|
25
|
-
model.coef_ = params[0]
|
|
26
|
-
if model.fit_intercept:
|
|
27
|
-
model.intercept_ = params[1]
|
|
28
|
-
return model
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
def set_initial_params(model):
|
|
32
|
-
n_classes = 10 # MNIST has 10 classes
|
|
33
|
-
n_features = 784 # Number of features in dataset
|
|
34
|
-
model.classes_ = np.array([i for i in range(10)])
|
|
35
|
-
|
|
36
|
-
model.coef_ = np.zeros((n_classes, n_features))
|
|
37
|
-
if model.fit_intercept:
|
|
38
|
-
model.intercept_ = np.zeros((n_classes,))
|
|
7
|
+
from flwr.client import ClientApp, NumPyClient
|
|
8
|
+
from flwr.common import Context
|
|
9
|
+
from $import_name.task import (
|
|
10
|
+
get_model,
|
|
11
|
+
get_model_params,
|
|
12
|
+
load_data,
|
|
13
|
+
set_initial_params,
|
|
14
|
+
set_model_params,
|
|
15
|
+
)
|
|
39
16
|
|
|
40
17
|
|
|
41
18
|
class FlowerClient(NumPyClient):
|
|
@@ -46,9 +23,6 @@ class FlowerClient(NumPyClient):
|
|
|
46
23
|
self.y_train = y_train
|
|
47
24
|
self.y_test = y_test
|
|
48
25
|
|
|
49
|
-
def get_parameters(self, config):
|
|
50
|
-
return get_model_parameters(self.model)
|
|
51
|
-
|
|
52
26
|
def fit(self, parameters, config):
|
|
53
27
|
set_model_params(self.model, parameters)
|
|
54
28
|
|
|
@@ -57,7 +31,7 @@ class FlowerClient(NumPyClient):
|
|
|
57
31
|
warnings.simplefilter("ignore")
|
|
58
32
|
self.model.fit(self.X_train, self.y_train)
|
|
59
33
|
|
|
60
|
-
return
|
|
34
|
+
return get_model_params(self.model), len(self.X_train), {}
|
|
61
35
|
|
|
62
36
|
def evaluate(self, parameters, config):
|
|
63
37
|
set_model_params(self.model, parameters)
|
|
@@ -71,21 +45,13 @@ class FlowerClient(NumPyClient):
|
|
|
71
45
|
def client_fn(context: Context):
|
|
72
46
|
partition_id = context.node_config["partition-id"]
|
|
73
47
|
num_partitions = context.node_config["num-partitions"]
|
|
74
|
-
fds = FederatedDataset(dataset="mnist", partitioners={"train": num_partitions})
|
|
75
|
-
dataset = fds.load_partition(partition_id, "train").with_format("numpy")
|
|
76
|
-
|
|
77
|
-
X, y = dataset["image"].reshape((len(dataset), -1)), dataset["label"]
|
|
78
48
|
|
|
79
|
-
|
|
80
|
-
X_train, X_test = X[: int(0.8 * len(X))], X[int(0.8 * len(X)) :]
|
|
81
|
-
y_train, y_test = y[: int(0.8 * len(y))], y[int(0.8 * len(y)) :]
|
|
49
|
+
X_train, X_test, y_train, y_test = load_data(partition_id, num_partitions)
|
|
82
50
|
|
|
83
51
|
# Create LogisticRegression Model
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
warm_start=True, # prevent refreshing weights when fitting
|
|
88
|
-
)
|
|
52
|
+
penalty = context.run_config["penalty"]
|
|
53
|
+
local_epochs = context.run_config["local-epochs"]
|
|
54
|
+
model = get_model(penalty, local_epochs)
|
|
89
55
|
|
|
90
56
|
# Setting initial parameters, akin to model.compile for keras models
|
|
91
57
|
set_initial_params(model)
|
|
@@ -1,16 +1,22 @@
|
|
|
1
1
|
"""$project_name: A Flower / $framework_str app."""
|
|
2
2
|
|
|
3
|
-
from flwr.common import Context
|
|
4
|
-
from flwr.server.strategy import FedAvg
|
|
3
|
+
from flwr.common import Context, ndarrays_to_parameters
|
|
5
4
|
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
|
|
5
|
+
from flwr.server.strategy import FedAvg
|
|
6
|
+
from $import_name.task import get_params, load_model
|
|
6
7
|
|
|
7
8
|
|
|
8
9
|
def server_fn(context: Context):
|
|
9
10
|
# Read from config
|
|
10
11
|
num_rounds = context.run_config["num-server-rounds"]
|
|
12
|
+
input_dim = context.run_config["input-dim"]
|
|
13
|
+
|
|
14
|
+
# Initialize global model
|
|
15
|
+
params = get_params(load_model((input_dim,)))
|
|
16
|
+
initial_parameters = ndarrays_to_parameters(params)
|
|
11
17
|
|
|
12
18
|
# Define strategy
|
|
13
|
-
strategy = FedAvg()
|
|
19
|
+
strategy = FedAvg(initial_parameters=initial_parameters)
|
|
14
20
|
config = ServerConfig(num_rounds=num_rounds)
|
|
15
21
|
|
|
16
22
|
return ServerAppComponents(strategy=strategy, config=config)
|
|
@@ -1,16 +1,27 @@
|
|
|
1
1
|
"""$project_name: A Flower / $framework_str app."""
|
|
2
2
|
|
|
3
|
-
from flwr.common import Context
|
|
3
|
+
from flwr.common import Context, ndarrays_to_parameters
|
|
4
4
|
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
|
|
5
5
|
from flwr.server.strategy import FedAvg
|
|
6
|
+
from $import_name.task import MLP, get_params
|
|
6
7
|
|
|
7
8
|
|
|
8
9
|
def server_fn(context: Context):
|
|
9
10
|
# Read from config
|
|
10
11
|
num_rounds = context.run_config["num-server-rounds"]
|
|
11
12
|
|
|
13
|
+
num_classes = 10
|
|
14
|
+
num_layers = context.run_config["num-layers"]
|
|
15
|
+
input_dim = context.run_config["input-dim"]
|
|
16
|
+
hidden_dim = context.run_config["hidden-dim"]
|
|
17
|
+
|
|
18
|
+
# Initialize global model
|
|
19
|
+
model = MLP(num_layers, input_dim, hidden_dim, num_classes)
|
|
20
|
+
params = get_params(model)
|
|
21
|
+
initial_parameters = ndarrays_to_parameters(params)
|
|
22
|
+
|
|
12
23
|
# Define strategy
|
|
13
|
-
strategy = FedAvg()
|
|
24
|
+
strategy = FedAvg(initial_parameters=initial_parameters)
|
|
14
25
|
config = ServerConfig(num_rounds=num_rounds)
|
|
15
26
|
|
|
16
27
|
return ServerAppComponents(strategy=strategy, config=config)
|