flwr-nightly 1.10.0.dev20240707__py3-none-any.whl → 1.11.0.dev20240724__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/build.py +16 -2
- flwr/cli/config_utils.py +47 -27
- flwr/cli/install.py +17 -1
- flwr/cli/new/new.py +32 -21
- flwr/cli/new/templates/app/code/{client.hf.py.tpl → client.huggingface.py.tpl} +15 -5
- flwr/cli/new/templates/app/code/client.jax.py.tpl +2 -1
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +36 -13
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +2 -1
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +16 -5
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +6 -3
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +25 -5
- flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +22 -19
- flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +5 -3
- flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -0
- flwr/cli/new/templates/app/code/server.jax.py.tpl +16 -8
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +12 -7
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +16 -8
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +15 -13
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +17 -10
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +16 -13
- flwr/cli/new/templates/app/code/{task.hf.py.tpl → task.huggingface.py.tpl} +14 -2
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +14 -2
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +14 -3
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +13 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +9 -12
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +38 -0
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +17 -11
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +17 -12
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +12 -12
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +13 -12
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +12 -12
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +15 -12
- flwr/cli/run/run.py +133 -54
- flwr/client/app.py +56 -24
- flwr/client/client_app.py +28 -8
- flwr/client/grpc_adapter_client/connection.py +3 -2
- flwr/client/grpc_client/connection.py +3 -2
- flwr/client/grpc_rere_client/connection.py +17 -6
- flwr/client/message_handler/message_handler.py +1 -1
- flwr/client/node_state.py +59 -12
- flwr/client/node_state_tests.py +4 -3
- flwr/client/rest_client/connection.py +19 -8
- flwr/client/supernode/app.py +39 -39
- flwr/client/typing.py +2 -2
- flwr/common/config.py +92 -2
- flwr/common/constant.py +3 -0
- flwr/common/context.py +24 -9
- flwr/common/logger.py +25 -0
- flwr/common/object_ref.py +84 -21
- flwr/common/serde.py +45 -0
- flwr/common/telemetry.py +17 -0
- flwr/common/typing.py +5 -0
- flwr/proto/common_pb2.py +36 -0
- flwr/proto/common_pb2.pyi +121 -0
- flwr/proto/common_pb2_grpc.py +4 -0
- flwr/proto/common_pb2_grpc.pyi +4 -0
- flwr/proto/driver_pb2.py +24 -19
- flwr/proto/driver_pb2.pyi +21 -1
- flwr/proto/exec_pb2.py +20 -11
- flwr/proto/exec_pb2.pyi +41 -1
- flwr/proto/run_pb2.py +12 -7
- flwr/proto/run_pb2.pyi +22 -1
- flwr/proto/task_pb2.py +7 -8
- flwr/server/__init__.py +2 -0
- flwr/server/compat/legacy_context.py +5 -4
- flwr/server/driver/grpc_driver.py +82 -140
- flwr/server/run_serverapp.py +40 -18
- flwr/server/server_app.py +56 -10
- flwr/server/serverapp_components.py +52 -0
- flwr/server/superlink/driver/driver_servicer.py +18 -3
- flwr/server/superlink/fleet/message_handler/message_handler.py +13 -2
- flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
- flwr/server/superlink/fleet/vce/backend/backend.py +4 -4
- flwr/server/superlink/fleet/vce/backend/raybackend.py +10 -10
- flwr/server/superlink/fleet/vce/vce_api.py +149 -117
- flwr/server/superlink/state/in_memory_state.py +11 -3
- flwr/server/superlink/state/sqlite_state.py +23 -8
- flwr/server/superlink/state/state.py +7 -2
- flwr/server/typing.py +2 -0
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -2
- flwr/simulation/__init__.py +1 -1
- flwr/simulation/app.py +4 -3
- flwr/simulation/ray_transport/ray_actor.py +15 -19
- flwr/simulation/ray_transport/ray_client_proxy.py +22 -9
- flwr/simulation/run_simulation.py +269 -70
- flwr/superexec/app.py +17 -11
- flwr/superexec/deployment.py +111 -35
- flwr/superexec/exec_grpc.py +5 -1
- flwr/superexec/exec_servicer.py +6 -1
- flwr/superexec/executor.py +21 -0
- flwr/superexec/simulation.py +181 -0
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/METADATA +3 -2
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/RECORD +97 -91
- flwr/cli/new/templates/app/code/server.hf.py.tpl +0 -17
- flwr/cli/new/templates/app/pyproject.hf.toml.tpl +0 -37
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/entry_points.txt +0 -0
flwr/cli/build.py
CHANGED
|
@@ -20,6 +20,7 @@ from pathlib import Path
|
|
|
20
20
|
from typing import Optional
|
|
21
21
|
|
|
22
22
|
import pathspec
|
|
23
|
+
import tomli_w
|
|
23
24
|
import typer
|
|
24
25
|
from typing_extensions import Annotated
|
|
25
26
|
|
|
@@ -85,7 +86,7 @@ def build(
|
|
|
85
86
|
|
|
86
87
|
# Set the name of the zip file
|
|
87
88
|
fab_filename = (
|
|
88
|
-
f"{conf['
|
|
89
|
+
f"{conf['tool']['flwr']['app']['publisher']}"
|
|
89
90
|
f".{directory.name}"
|
|
90
91
|
f".{conf['project']['version'].replace('.', '-')}.fab"
|
|
91
92
|
)
|
|
@@ -93,15 +94,28 @@ def build(
|
|
|
93
94
|
|
|
94
95
|
allowed_extensions = {".py", ".toml", ".md"}
|
|
95
96
|
|
|
97
|
+
# Remove the 'federations' field from 'tool.flwr' if it exists
|
|
98
|
+
if (
|
|
99
|
+
"tool" in conf
|
|
100
|
+
and "flwr" in conf["tool"]
|
|
101
|
+
and "federations" in conf["tool"]["flwr"]
|
|
102
|
+
):
|
|
103
|
+
del conf["tool"]["flwr"]["federations"]
|
|
104
|
+
|
|
105
|
+
toml_contents = tomli_w.dumps(conf)
|
|
106
|
+
|
|
96
107
|
with zipfile.ZipFile(fab_filename, "w", zipfile.ZIP_DEFLATED) as fab_file:
|
|
108
|
+
fab_file.writestr("pyproject.toml", toml_contents)
|
|
109
|
+
|
|
110
|
+
# Continue with adding other files
|
|
97
111
|
for root, _, files in os.walk(directory, topdown=True):
|
|
98
|
-
# Filter directories and files based on .gitignore
|
|
99
112
|
files = [
|
|
100
113
|
f
|
|
101
114
|
for f in files
|
|
102
115
|
if not ignore_spec.match_file(Path(root) / f)
|
|
103
116
|
and f != fab_filename
|
|
104
117
|
and Path(f).suffix in allowed_extensions
|
|
118
|
+
and f != "pyproject.toml" # Exclude the original pyproject.toml
|
|
105
119
|
]
|
|
106
120
|
|
|
107
121
|
for file in files:
|
flwr/cli/config_utils.py
CHANGED
|
@@ -17,11 +17,12 @@
|
|
|
17
17
|
import zipfile
|
|
18
18
|
from io import BytesIO
|
|
19
19
|
from pathlib import Path
|
|
20
|
-
from typing import IO, Any, Dict, List, Optional, Tuple, Union
|
|
20
|
+
from typing import IO, Any, Dict, List, Optional, Tuple, Union, get_args
|
|
21
21
|
|
|
22
22
|
import tomli
|
|
23
23
|
|
|
24
24
|
from flwr.common import object_ref
|
|
25
|
+
from flwr.common.typing import UserConfigValue
|
|
25
26
|
|
|
26
27
|
|
|
27
28
|
def get_fab_metadata(fab_file: Union[Path, bytes]) -> Tuple[str, str]:
|
|
@@ -60,7 +61,7 @@ def get_fab_metadata(fab_file: Union[Path, bytes]) -> Tuple[str, str]:
|
|
|
60
61
|
|
|
61
62
|
return (
|
|
62
63
|
conf["project"]["version"],
|
|
63
|
-
f"{conf['
|
|
64
|
+
f"{conf['tool']['flwr']['app']['publisher']}/{conf['project']['name']}",
|
|
64
65
|
)
|
|
65
66
|
|
|
66
67
|
|
|
@@ -76,6 +77,9 @@ def load_and_validate(
|
|
|
76
77
|
A tuple with the optional config in case it exists and is valid
|
|
77
78
|
and associated errors and warnings.
|
|
78
79
|
"""
|
|
80
|
+
if path is None:
|
|
81
|
+
path = Path.cwd() / "pyproject.toml"
|
|
82
|
+
|
|
79
83
|
config = load(path)
|
|
80
84
|
|
|
81
85
|
if config is None:
|
|
@@ -85,7 +89,7 @@ def load_and_validate(
|
|
|
85
89
|
]
|
|
86
90
|
return (None, errors, [])
|
|
87
91
|
|
|
88
|
-
is_valid, errors, warnings = validate(config, check_module)
|
|
92
|
+
is_valid, errors, warnings = validate(config, check_module, path.parent)
|
|
89
93
|
|
|
90
94
|
if not is_valid:
|
|
91
95
|
return (None, errors, warnings)
|
|
@@ -93,14 +97,8 @@ def load_and_validate(
|
|
|
93
97
|
return (config, errors, warnings)
|
|
94
98
|
|
|
95
99
|
|
|
96
|
-
def load(
|
|
100
|
+
def load(toml_path: Path) -> Optional[Dict[str, Any]]:
|
|
97
101
|
"""Load pyproject.toml and return as dict."""
|
|
98
|
-
if path is None:
|
|
99
|
-
cur_dir = Path.cwd()
|
|
100
|
-
toml_path = cur_dir / "pyproject.toml"
|
|
101
|
-
else:
|
|
102
|
-
toml_path = path
|
|
103
|
-
|
|
104
102
|
if not toml_path.is_file():
|
|
105
103
|
return None
|
|
106
104
|
|
|
@@ -108,6 +106,17 @@ def load(path: Optional[Path] = None) -> Optional[Dict[str, Any]]:
|
|
|
108
106
|
return load_from_string(toml_file.read())
|
|
109
107
|
|
|
110
108
|
|
|
109
|
+
def _validate_run_config(config_dict: Dict[str, Any], errors: List[str]) -> None:
|
|
110
|
+
for key, value in config_dict.items():
|
|
111
|
+
if isinstance(value, dict):
|
|
112
|
+
_validate_run_config(config_dict[key], errors)
|
|
113
|
+
elif not isinstance(value, get_args(UserConfigValue)):
|
|
114
|
+
raise ValueError(
|
|
115
|
+
f"The value for key {key} needs to be of type `int`, `float`, "
|
|
116
|
+
"`bool, `str`, or a `dict` of those.",
|
|
117
|
+
)
|
|
118
|
+
|
|
119
|
+
|
|
111
120
|
# pylint: disable=too-many-branches
|
|
112
121
|
def validate_fields(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]]:
|
|
113
122
|
"""Validate pyproject.toml fields."""
|
|
@@ -128,24 +137,36 @@ def validate_fields(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]]
|
|
|
128
137
|
if "authors" not in config["project"]:
|
|
129
138
|
warnings.append('Recommended property "authors" missing in [project]')
|
|
130
139
|
|
|
131
|
-
if
|
|
132
|
-
|
|
140
|
+
if (
|
|
141
|
+
"tool" not in config
|
|
142
|
+
or "flwr" not in config["tool"]
|
|
143
|
+
or "app" not in config["tool"]["flwr"]
|
|
144
|
+
):
|
|
145
|
+
errors.append("Missing [tool.flwr.app] section")
|
|
133
146
|
else:
|
|
134
|
-
if "publisher" not in config["
|
|
135
|
-
errors.append('Property "publisher" missing in [
|
|
136
|
-
if "
|
|
137
|
-
|
|
147
|
+
if "publisher" not in config["tool"]["flwr"]["app"]:
|
|
148
|
+
errors.append('Property "publisher" missing in [tool.flwr.app]')
|
|
149
|
+
if "config" in config["tool"]["flwr"]["app"]:
|
|
150
|
+
_validate_run_config(config["tool"]["flwr"]["app"]["config"], errors)
|
|
151
|
+
if "components" not in config["tool"]["flwr"]["app"]:
|
|
152
|
+
errors.append("Missing [tool.flwr.app.components] section")
|
|
138
153
|
else:
|
|
139
|
-
if "serverapp" not in config["
|
|
140
|
-
errors.append(
|
|
141
|
-
|
|
142
|
-
|
|
154
|
+
if "serverapp" not in config["tool"]["flwr"]["app"]["components"]:
|
|
155
|
+
errors.append(
|
|
156
|
+
'Property "serverapp" missing in [tool.flwr.app.components]'
|
|
157
|
+
)
|
|
158
|
+
if "clientapp" not in config["tool"]["flwr"]["app"]["components"]:
|
|
159
|
+
errors.append(
|
|
160
|
+
'Property "clientapp" missing in [tool.flwr.app.components]'
|
|
161
|
+
)
|
|
143
162
|
|
|
144
163
|
return len(errors) == 0, errors, warnings
|
|
145
164
|
|
|
146
165
|
|
|
147
166
|
def validate(
|
|
148
|
-
config: Dict[str, Any],
|
|
167
|
+
config: Dict[str, Any],
|
|
168
|
+
check_module: bool = True,
|
|
169
|
+
project_dir: Optional[Union[str, Path]] = None,
|
|
149
170
|
) -> Tuple[bool, List[str], List[str]]:
|
|
150
171
|
"""Validate pyproject.toml."""
|
|
151
172
|
is_valid, errors, warnings = validate_fields(config)
|
|
@@ -154,16 +175,15 @@ def validate(
|
|
|
154
175
|
return False, errors, warnings
|
|
155
176
|
|
|
156
177
|
# Validate serverapp
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
178
|
+
serverapp_ref = config["tool"]["flwr"]["app"]["components"]["serverapp"]
|
|
179
|
+
is_valid, reason = object_ref.validate(serverapp_ref, check_module, project_dir)
|
|
180
|
+
|
|
160
181
|
if not is_valid and isinstance(reason, str):
|
|
161
182
|
return False, [reason], []
|
|
162
183
|
|
|
163
184
|
# Validate clientapp
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
)
|
|
185
|
+
clientapp_ref = config["tool"]["flwr"]["app"]["components"]["clientapp"]
|
|
186
|
+
is_valid, reason = object_ref.validate(clientapp_ref, check_module, project_dir)
|
|
167
187
|
|
|
168
188
|
if not is_valid and isinstance(reason, str):
|
|
169
189
|
return False, [reason], []
|
flwr/cli/install.py
CHANGED
|
@@ -16,6 +16,7 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import shutil
|
|
19
|
+
import subprocess
|
|
19
20
|
import tempfile
|
|
20
21
|
import zipfile
|
|
21
22
|
from io import BytesIO
|
|
@@ -149,7 +150,7 @@ def validate_and_install(
|
|
|
149
150
|
)
|
|
150
151
|
raise typer.Exit(code=1)
|
|
151
152
|
|
|
152
|
-
publisher = config["
|
|
153
|
+
publisher = config["tool"]["flwr"]["app"]["publisher"]
|
|
153
154
|
project_name = config["project"]["name"]
|
|
154
155
|
version = config["project"]["version"]
|
|
155
156
|
|
|
@@ -192,6 +193,21 @@ def validate_and_install(
|
|
|
192
193
|
else:
|
|
193
194
|
shutil.copy2(item, install_dir / item.name)
|
|
194
195
|
|
|
196
|
+
try:
|
|
197
|
+
subprocess.run(
|
|
198
|
+
["pip", "install", "-e", install_dir, "--no-deps"],
|
|
199
|
+
capture_output=True,
|
|
200
|
+
text=True,
|
|
201
|
+
check=True,
|
|
202
|
+
)
|
|
203
|
+
except subprocess.CalledProcessError as e:
|
|
204
|
+
typer.secho(
|
|
205
|
+
f"❌ Failed to `pip install` package(s) from {install_dir}:\n{e.stderr}",
|
|
206
|
+
fg=typer.colors.RED,
|
|
207
|
+
bold=True,
|
|
208
|
+
)
|
|
209
|
+
raise typer.Exit(code=1) from e
|
|
210
|
+
|
|
195
211
|
typer.secho(
|
|
196
212
|
f"🎊 Successfully installed {project_name} to {install_dir}.",
|
|
197
213
|
fg=typer.colors.GREEN,
|
flwr/cli/new/new.py
CHANGED
|
@@ -14,9 +14,9 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Flower command line interface `new` command."""
|
|
16
16
|
|
|
17
|
-
import os
|
|
18
17
|
import re
|
|
19
18
|
from enum import Enum
|
|
19
|
+
from pathlib import Path
|
|
20
20
|
from string import Template
|
|
21
21
|
from typing import Dict, Optional
|
|
22
22
|
|
|
@@ -38,7 +38,7 @@ class MlFramework(str, Enum):
|
|
|
38
38
|
PYTORCH = "PyTorch"
|
|
39
39
|
TENSORFLOW = "TensorFlow"
|
|
40
40
|
JAX = "JAX"
|
|
41
|
-
HUGGINGFACE = "
|
|
41
|
+
HUGGINGFACE = "HuggingFace"
|
|
42
42
|
MLX = "MLX"
|
|
43
43
|
SKLEARN = "sklearn"
|
|
44
44
|
FLOWERTUNE = "FlowerTune"
|
|
@@ -59,10 +59,10 @@ class TemplateNotFound(Exception):
|
|
|
59
59
|
|
|
60
60
|
def load_template(name: str) -> str:
|
|
61
61
|
"""Load template from template directory and return as text."""
|
|
62
|
-
tpl_dir =
|
|
63
|
-
tpl_file_path =
|
|
62
|
+
tpl_dir = (Path(__file__).parent / "templates").absolute()
|
|
63
|
+
tpl_file_path = tpl_dir / name
|
|
64
64
|
|
|
65
|
-
if not
|
|
65
|
+
if not tpl_file_path.is_file():
|
|
66
66
|
raise TemplateNotFound(f"Template '{name}' not found")
|
|
67
67
|
|
|
68
68
|
with open(tpl_file_path, encoding="utf-8") as tpl_file:
|
|
@@ -78,14 +78,13 @@ def render_template(template: str, data: Dict[str, str]) -> str:
|
|
|
78
78
|
return tpl.template
|
|
79
79
|
|
|
80
80
|
|
|
81
|
-
def create_file(file_path:
|
|
81
|
+
def create_file(file_path: Path, content: str) -> None:
|
|
82
82
|
"""Create file including all nessecary directories and write content into file."""
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
f.write(content)
|
|
83
|
+
file_path.parent.mkdir(exist_ok=True)
|
|
84
|
+
file_path.write_text(content)
|
|
86
85
|
|
|
87
86
|
|
|
88
|
-
def render_and_create(file_path:
|
|
87
|
+
def render_and_create(file_path: Path, template: str, context: Dict[str, str]) -> None:
|
|
89
88
|
"""Render template and write to file."""
|
|
90
89
|
content = render_template(template, context)
|
|
91
90
|
create_file(file_path, content)
|
|
@@ -117,6 +116,21 @@ def new(
|
|
|
117
116
|
default=sanitize_project_name(project_name),
|
|
118
117
|
)
|
|
119
118
|
|
|
119
|
+
# Set project directory path
|
|
120
|
+
package_name = re.sub(r"[-_.]+", "-", project_name).lower()
|
|
121
|
+
import_name = package_name.replace("-", "_")
|
|
122
|
+
project_dir = Path.cwd() / package_name
|
|
123
|
+
|
|
124
|
+
if project_dir.exists():
|
|
125
|
+
if not typer.confirm(
|
|
126
|
+
typer.style(
|
|
127
|
+
f"\n💬 {project_name} already exists, do you want to override it?",
|
|
128
|
+
fg=typer.colors.MAGENTA,
|
|
129
|
+
bold=True,
|
|
130
|
+
)
|
|
131
|
+
):
|
|
132
|
+
return
|
|
133
|
+
|
|
120
134
|
if username is None:
|
|
121
135
|
username = prompt_text("Please provide your Flower username")
|
|
122
136
|
|
|
@@ -136,6 +150,7 @@ def new(
|
|
|
136
150
|
|
|
137
151
|
framework_str = framework_str.lower()
|
|
138
152
|
|
|
153
|
+
llm_challenge_str = None
|
|
139
154
|
if framework_str == "flowertune":
|
|
140
155
|
llm_challenge_value = prompt_options(
|
|
141
156
|
"Please select LLM challenge by typing in the number",
|
|
@@ -157,12 +172,6 @@ def new(
|
|
|
157
172
|
)
|
|
158
173
|
)
|
|
159
174
|
|
|
160
|
-
# Set project directory path
|
|
161
|
-
cwd = os.getcwd()
|
|
162
|
-
package_name = re.sub(r"[-_.]+", "-", project_name).lower()
|
|
163
|
-
import_name = package_name.replace("-", "_")
|
|
164
|
-
project_dir = os.path.join(cwd, package_name)
|
|
165
|
-
|
|
166
175
|
context = {
|
|
167
176
|
"project_name": project_name,
|
|
168
177
|
"package_name": package_name,
|
|
@@ -171,7 +180,7 @@ def new(
|
|
|
171
180
|
}
|
|
172
181
|
|
|
173
182
|
# List of files to render
|
|
174
|
-
if
|
|
183
|
+
if llm_challenge_str:
|
|
175
184
|
files = {
|
|
176
185
|
".gitignore": {"template": "app/.gitignore.tpl"},
|
|
177
186
|
"pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"},
|
|
@@ -228,10 +237,10 @@ def new(
|
|
|
228
237
|
"README.md": {"template": "app/README.md.tpl"},
|
|
229
238
|
"pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"},
|
|
230
239
|
f"{import_name}/__init__.py": {"template": "app/code/__init__.py.tpl"},
|
|
231
|
-
f"{import_name}/
|
|
240
|
+
f"{import_name}/server_app.py": {
|
|
232
241
|
"template": f"app/code/server.{framework_str}.py.tpl"
|
|
233
242
|
},
|
|
234
|
-
f"{import_name}/
|
|
243
|
+
f"{import_name}/client_app.py": {
|
|
235
244
|
"template": f"app/code/client.{framework_str}.py.tpl"
|
|
236
245
|
},
|
|
237
246
|
}
|
|
@@ -251,7 +260,7 @@ def new(
|
|
|
251
260
|
|
|
252
261
|
for file_path, value in files.items():
|
|
253
262
|
render_and_create(
|
|
254
|
-
file_path=
|
|
263
|
+
file_path=project_dir / file_path,
|
|
255
264
|
template=value["template"],
|
|
256
265
|
context=context,
|
|
257
266
|
)
|
|
@@ -264,9 +273,11 @@ def new(
|
|
|
264
273
|
bold=True,
|
|
265
274
|
)
|
|
266
275
|
)
|
|
276
|
+
|
|
277
|
+
_add = " huggingface-cli login\n" if framework_str == "flowertune" else ""
|
|
267
278
|
print(
|
|
268
279
|
typer.style(
|
|
269
|
-
f" cd {package_name}\n" + " pip install -e .\n flwr run\n",
|
|
280
|
+
f" cd {package_name}\n" + " pip install -e .\n" + _add + " flwr run\n",
|
|
270
281
|
fg=typer.colors.BRIGHT_CYAN,
|
|
271
282
|
bold=True,
|
|
272
283
|
)
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"""$project_name: A Flower / HuggingFace Transformers app."""
|
|
2
2
|
|
|
3
3
|
from flwr.client import ClientApp, NumPyClient
|
|
4
|
+
from flwr.common import Context
|
|
4
5
|
from transformers import AutoModelForSequenceClassification
|
|
5
6
|
|
|
6
7
|
from $import_name.task import (
|
|
@@ -16,10 +17,11 @@ from $import_name.task import (
|
|
|
16
17
|
|
|
17
18
|
# Flower client
|
|
18
19
|
class FlowerClient(NumPyClient):
|
|
19
|
-
def __init__(self, net, trainloader, testloader):
|
|
20
|
+
def __init__(self, net, trainloader, testloader, local_epochs):
|
|
20
21
|
self.net = net
|
|
21
22
|
self.trainloader = trainloader
|
|
22
23
|
self.testloader = testloader
|
|
24
|
+
self.local_epochs = local_epochs
|
|
23
25
|
|
|
24
26
|
def get_parameters(self, config):
|
|
25
27
|
return get_weights(self.net)
|
|
@@ -29,7 +31,11 @@ class FlowerClient(NumPyClient):
|
|
|
29
31
|
|
|
30
32
|
def fit(self, parameters, config):
|
|
31
33
|
self.set_parameters(parameters)
|
|
32
|
-
train(
|
|
34
|
+
train(
|
|
35
|
+
self.net,
|
|
36
|
+
self.trainloader,
|
|
37
|
+
epochs=self.local_epochs,
|
|
38
|
+
)
|
|
33
39
|
return self.get_parameters(config={}), len(self.trainloader), {}
|
|
34
40
|
|
|
35
41
|
def evaluate(self, parameters, config):
|
|
@@ -38,15 +44,19 @@ class FlowerClient(NumPyClient):
|
|
|
38
44
|
return float(loss), len(self.testloader), {"accuracy": accuracy}
|
|
39
45
|
|
|
40
46
|
|
|
41
|
-
def client_fn(
|
|
47
|
+
def client_fn(context: Context):
|
|
42
48
|
# Load model and data
|
|
43
49
|
net = AutoModelForSequenceClassification.from_pretrained(
|
|
44
50
|
CHECKPOINT, num_labels=2
|
|
45
51
|
).to(DEVICE)
|
|
46
|
-
|
|
52
|
+
|
|
53
|
+
partition_id = context.node_config["partition-id"]
|
|
54
|
+
num_partitions = context.node_config["num-partitions"]
|
|
55
|
+
trainloader, valloader = load_data(partition_id, num_partitions)
|
|
56
|
+
local_epochs = context.run_config["local-epochs"]
|
|
47
57
|
|
|
48
58
|
# Return Client instance
|
|
49
|
-
return FlowerClient(net, trainloader, valloader).to_client()
|
|
59
|
+
return FlowerClient(net, trainloader, valloader, local_epochs).to_client()
|
|
50
60
|
|
|
51
61
|
|
|
52
62
|
# Flower ClientApp
|
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
import jax
|
|
4
4
|
from flwr.client import NumPyClient, ClientApp
|
|
5
|
+
from flwr.common import Context
|
|
5
6
|
|
|
6
7
|
from $import_name.task import (
|
|
7
8
|
evaluation,
|
|
@@ -44,7 +45,7 @@ class FlowerClient(NumPyClient):
|
|
|
44
45
|
)
|
|
45
46
|
return float(loss), num_examples, {"loss": float(loss)}
|
|
46
47
|
|
|
47
|
-
def client_fn(
|
|
48
|
+
def client_fn(context: Context):
|
|
48
49
|
# Return Client instance
|
|
49
50
|
return FlowerClient().to_client()
|
|
50
51
|
|
|
@@ -4,6 +4,7 @@ import mlx.core as mx
|
|
|
4
4
|
import mlx.nn as nn
|
|
5
5
|
import mlx.optimizers as optim
|
|
6
6
|
from flwr.client import NumPyClient, ClientApp
|
|
7
|
+
from flwr.common import Context
|
|
7
8
|
|
|
8
9
|
from $import_name.task import (
|
|
9
10
|
batch_iterate,
|
|
@@ -18,18 +19,29 @@ from $import_name.task import (
|
|
|
18
19
|
|
|
19
20
|
# Define Flower Client and client_fn
|
|
20
21
|
class FlowerClient(NumPyClient):
|
|
21
|
-
def __init__(
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
22
|
+
def __init__(
|
|
23
|
+
self,
|
|
24
|
+
data,
|
|
25
|
+
num_layers,
|
|
26
|
+
hidden_dim,
|
|
27
|
+
num_classes,
|
|
28
|
+
batch_size,
|
|
29
|
+
learning_rate,
|
|
30
|
+
num_epochs,
|
|
31
|
+
):
|
|
32
|
+
self.num_layers = num_layers
|
|
33
|
+
self.hidden_dim = hidden_dim
|
|
34
|
+
self.num_classes = num_classes
|
|
35
|
+
self.batch_size = batch_size
|
|
36
|
+
self.learning_rate = learning_rate
|
|
37
|
+
self.num_epochs = num_epochs
|
|
28
38
|
|
|
29
39
|
self.train_images, self.train_labels, self.test_images, self.test_labels = data
|
|
30
|
-
self.model = MLP(
|
|
31
|
-
|
|
32
|
-
|
|
40
|
+
self.model = MLP(
|
|
41
|
+
num_layers, self.train_images.shape[-1], hidden_dim, num_classes
|
|
42
|
+
)
|
|
43
|
+
self.optimizer = optim.SGD(learning_rate=learning_rate)
|
|
44
|
+
self.loss_and_grad_fn = nn.value_and_grad(self.model, loss_fn)
|
|
33
45
|
self.num_epochs = num_epochs
|
|
34
46
|
self.batch_size = batch_size
|
|
35
47
|
|
|
@@ -57,11 +69,22 @@ class FlowerClient(NumPyClient):
|
|
|
57
69
|
return loss.item(), len(self.test_images), {"accuracy": accuracy.item()}
|
|
58
70
|
|
|
59
71
|
|
|
60
|
-
def client_fn(
|
|
61
|
-
|
|
72
|
+
def client_fn(context: Context):
|
|
73
|
+
partition_id = context.node_config["partition-id"]
|
|
74
|
+
num_partitions = context.node_config["num-partitions"]
|
|
75
|
+
data = load_data(partition_id, num_partitions)
|
|
76
|
+
|
|
77
|
+
num_layers = context.run_config["num-layers"]
|
|
78
|
+
hidden_dim = context.run_config["hidden-dim"]
|
|
79
|
+
num_classes = 10
|
|
80
|
+
batch_size = context.run_config["batch-size"]
|
|
81
|
+
learning_rate = context.run_config["lr"]
|
|
82
|
+
num_epochs = context.run_config["local-epochs"]
|
|
62
83
|
|
|
63
84
|
# Return Client instance
|
|
64
|
-
return FlowerClient(
|
|
85
|
+
return FlowerClient(
|
|
86
|
+
data, num_layers, hidden_dim, num_classes, batch_size, learning_rate, num_epochs
|
|
87
|
+
).to_client()
|
|
65
88
|
|
|
66
89
|
|
|
67
90
|
# Flower ClientApp
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"""$project_name: A Flower / NumPy app."""
|
|
2
2
|
|
|
3
3
|
from flwr.client import NumPyClient, ClientApp
|
|
4
|
+
from flwr.common import Context
|
|
4
5
|
import numpy as np
|
|
5
6
|
|
|
6
7
|
|
|
@@ -15,7 +16,7 @@ class FlowerClient(NumPyClient):
|
|
|
15
16
|
return float(0.0), 1, {"accuracy": float(1.0)}
|
|
16
17
|
|
|
17
18
|
|
|
18
|
-
def client_fn(
|
|
19
|
+
def client_fn(context: Context):
|
|
19
20
|
return FlowerClient().to_client()
|
|
20
21
|
|
|
21
22
|
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"""$project_name: A Flower / PyTorch app."""
|
|
2
2
|
|
|
3
3
|
from flwr.client import NumPyClient, ClientApp
|
|
4
|
+
from flwr.common import Context
|
|
4
5
|
|
|
5
6
|
from $import_name.task import (
|
|
6
7
|
Net,
|
|
@@ -15,14 +16,21 @@ from $import_name.task import (
|
|
|
15
16
|
|
|
16
17
|
# Define Flower Client and client_fn
|
|
17
18
|
class FlowerClient(NumPyClient):
|
|
18
|
-
def __init__(self, net, trainloader, valloader):
|
|
19
|
+
def __init__(self, net, trainloader, valloader, local_epochs):
|
|
19
20
|
self.net = net
|
|
20
21
|
self.trainloader = trainloader
|
|
21
22
|
self.valloader = valloader
|
|
23
|
+
self.local_epochs = local_epochs
|
|
22
24
|
|
|
23
25
|
def fit(self, parameters, config):
|
|
24
26
|
set_weights(self.net, parameters)
|
|
25
|
-
results = train(
|
|
27
|
+
results = train(
|
|
28
|
+
self.net,
|
|
29
|
+
self.trainloader,
|
|
30
|
+
self.valloader,
|
|
31
|
+
self.local_epochs,
|
|
32
|
+
DEVICE,
|
|
33
|
+
)
|
|
26
34
|
return get_weights(self.net), len(self.trainloader.dataset), results
|
|
27
35
|
|
|
28
36
|
def evaluate(self, parameters, config):
|
|
@@ -31,13 +39,16 @@ class FlowerClient(NumPyClient):
|
|
|
31
39
|
return loss, len(self.valloader.dataset), {"accuracy": accuracy}
|
|
32
40
|
|
|
33
41
|
|
|
34
|
-
def client_fn(
|
|
42
|
+
def client_fn(context: Context):
|
|
35
43
|
# Load model and data
|
|
36
44
|
net = Net().to(DEVICE)
|
|
37
|
-
|
|
45
|
+
partition_id = context.node_config["partition-id"]
|
|
46
|
+
num_partitions = context.node_config["num-partitions"]
|
|
47
|
+
trainloader, valloader = load_data(partition_id, num_partitions)
|
|
48
|
+
local_epochs = context.run_config["local-epochs"]
|
|
38
49
|
|
|
39
50
|
# Return Client instance
|
|
40
|
-
return FlowerClient(net, trainloader, valloader).to_client()
|
|
51
|
+
return FlowerClient(net, trainloader, valloader, local_epochs).to_client()
|
|
41
52
|
|
|
42
53
|
|
|
43
54
|
# Flower ClientApp
|
|
@@ -4,6 +4,7 @@ import warnings
|
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
from flwr.client import NumPyClient, ClientApp
|
|
7
|
+
from flwr.common import Context
|
|
7
8
|
from flwr_datasets import FederatedDataset
|
|
8
9
|
from sklearn.linear_model import LogisticRegression
|
|
9
10
|
from sklearn.metrics import log_loss
|
|
@@ -66,10 +67,12 @@ class FlowerClient(NumPyClient):
|
|
|
66
67
|
|
|
67
68
|
return loss, len(self.X_test), {"accuracy": accuracy}
|
|
68
69
|
|
|
69
|
-
fds = FederatedDataset(dataset="mnist", partitioners={"train": 2})
|
|
70
70
|
|
|
71
|
-
def client_fn(
|
|
72
|
-
|
|
71
|
+
def client_fn(context: Context):
|
|
72
|
+
partition_id = context.node_config["partition-id"]
|
|
73
|
+
num_partitions = context.node_config["num-partitions"]
|
|
74
|
+
fds = FederatedDataset(dataset="mnist", partitioners={"train": num_partitions})
|
|
75
|
+
dataset = fds.load_partition(partition_id, "train").with_format("numpy")
|
|
73
76
|
|
|
74
77
|
X, y = dataset["image"].reshape((len(dataset), -1)), dataset["label"]
|
|
75
78
|
|
|
@@ -1,25 +1,37 @@
|
|
|
1
1
|
"""$project_name: A Flower / TensorFlow app."""
|
|
2
2
|
|
|
3
3
|
from flwr.client import NumPyClient, ClientApp
|
|
4
|
+
from flwr.common import Context
|
|
4
5
|
|
|
5
6
|
from $import_name.task import load_data, load_model
|
|
6
7
|
|
|
7
8
|
|
|
8
9
|
# Define Flower Client and client_fn
|
|
9
10
|
class FlowerClient(NumPyClient):
|
|
10
|
-
def __init__(
|
|
11
|
+
def __init__(
|
|
12
|
+
self, model, x_train, y_train, x_test, y_test, epochs, batch_size, verbose
|
|
13
|
+
):
|
|
11
14
|
self.model = model
|
|
12
15
|
self.x_train = x_train
|
|
13
16
|
self.y_train = y_train
|
|
14
17
|
self.x_test = x_test
|
|
15
18
|
self.y_test = y_test
|
|
19
|
+
self.epochs = epochs
|
|
20
|
+
self.batch_size = batch_size
|
|
21
|
+
self.verbose = verbose
|
|
16
22
|
|
|
17
23
|
def get_parameters(self, config):
|
|
18
24
|
return self.model.get_weights()
|
|
19
25
|
|
|
20
26
|
def fit(self, parameters, config):
|
|
21
27
|
self.model.set_weights(parameters)
|
|
22
|
-
self.model.fit(
|
|
28
|
+
self.model.fit(
|
|
29
|
+
self.x_train,
|
|
30
|
+
self.y_train,
|
|
31
|
+
epochs=self.epochs,
|
|
32
|
+
batch_size=self.batch_size,
|
|
33
|
+
verbose=self.verbose,
|
|
34
|
+
)
|
|
23
35
|
return self.model.get_weights(), len(self.x_train), {}
|
|
24
36
|
|
|
25
37
|
def evaluate(self, parameters, config):
|
|
@@ -28,13 +40,21 @@ class FlowerClient(NumPyClient):
|
|
|
28
40
|
return loss, len(self.x_test), {"accuracy": accuracy}
|
|
29
41
|
|
|
30
42
|
|
|
31
|
-
def client_fn(
|
|
43
|
+
def client_fn(context: Context):
|
|
32
44
|
# Load model and data
|
|
33
45
|
net = load_model()
|
|
34
|
-
|
|
46
|
+
|
|
47
|
+
partition_id = context.node_config["partition-id"]
|
|
48
|
+
num_partitions = context.node_config["num-partitions"]
|
|
49
|
+
x_train, y_train, x_test, y_test = load_data(partition_id, num_partitions)
|
|
50
|
+
epochs = context.run_config["local-epochs"]
|
|
51
|
+
batch_size = context.run_config["batch-size"]
|
|
52
|
+
verbose = context.run_config.get("verbose")
|
|
35
53
|
|
|
36
54
|
# Return Client instance
|
|
37
|
-
return FlowerClient(
|
|
55
|
+
return FlowerClient(
|
|
56
|
+
net, x_train, y_train, x_test, y_test, epochs, batch_size, verbose
|
|
57
|
+
).to_client()
|
|
38
58
|
|
|
39
59
|
|
|
40
60
|
# Flower ClientApp
|