flwr-nightly 1.10.0.dev20240624__py3-none-any.whl → 1.10.0.dev20240722__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/build.py +18 -4
- flwr/cli/config_utils.py +36 -14
- flwr/cli/install.py +17 -1
- flwr/cli/new/new.py +31 -20
- flwr/cli/new/templates/app/code/client.hf.py.tpl +11 -3
- flwr/cli/new/templates/app/code/client.jax.py.tpl +2 -1
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +15 -10
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +2 -1
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +12 -3
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +6 -3
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +13 -3
- flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +2 -2
- flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.hf.py.tpl +16 -11
- flwr/cli/new/templates/app/code/server.jax.py.tpl +15 -8
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +11 -7
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +15 -8
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +15 -13
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +16 -10
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +16 -13
- flwr/cli/new/templates/app/code/task.hf.py.tpl +2 -2
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +2 -2
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +9 -12
- flwr/cli/new/templates/app/pyproject.hf.toml.tpl +17 -16
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +17 -11
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +17 -12
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +12 -12
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +13 -12
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +12 -12
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +15 -12
- flwr/cli/run/run.py +135 -51
- flwr/client/__init__.py +2 -0
- flwr/client/app.py +63 -26
- flwr/client/client_app.py +49 -4
- flwr/client/grpc_adapter_client/connection.py +3 -2
- flwr/client/grpc_client/connection.py +3 -2
- flwr/client/grpc_rere_client/connection.py +17 -6
- flwr/client/message_handler/message_handler.py +3 -4
- flwr/client/node_state.py +60 -10
- flwr/client/node_state_tests.py +4 -3
- flwr/client/rest_client/connection.py +19 -8
- flwr/client/supernode/app.py +60 -21
- flwr/client/typing.py +1 -0
- flwr/common/config.py +87 -2
- flwr/common/constant.py +6 -0
- flwr/common/context.py +26 -1
- flwr/common/logger.py +38 -0
- flwr/common/message.py +0 -17
- flwr/common/serde.py +45 -0
- flwr/common/telemetry.py +17 -0
- flwr/common/typing.py +5 -0
- flwr/proto/common_pb2.py +36 -0
- flwr/proto/common_pb2.pyi +121 -0
- flwr/proto/common_pb2_grpc.py +4 -0
- flwr/proto/common_pb2_grpc.pyi +4 -0
- flwr/proto/driver_pb2.py +24 -19
- flwr/proto/driver_pb2.pyi +21 -1
- flwr/proto/exec_pb2.py +16 -11
- flwr/proto/exec_pb2.pyi +22 -1
- flwr/proto/run_pb2.py +12 -7
- flwr/proto/run_pb2.pyi +22 -1
- flwr/proto/task_pb2.py +7 -8
- flwr/server/__init__.py +2 -0
- flwr/server/compat/legacy_context.py +5 -4
- flwr/server/driver/grpc_driver.py +82 -140
- flwr/server/run_serverapp.py +40 -15
- flwr/server/server_app.py +56 -10
- flwr/server/serverapp_components.py +52 -0
- flwr/server/superlink/driver/driver_servicer.py +18 -3
- flwr/server/superlink/fleet/message_handler/message_handler.py +13 -2
- flwr/server/superlink/fleet/vce/backend/backend.py +4 -4
- flwr/server/superlink/fleet/vce/backend/raybackend.py +10 -10
- flwr/server/superlink/fleet/vce/vce_api.py +149 -122
- flwr/server/superlink/state/in_memory_state.py +15 -7
- flwr/server/superlink/state/sqlite_state.py +27 -12
- flwr/server/superlink/state/state.py +7 -2
- flwr/server/superlink/state/utils.py +6 -0
- flwr/server/typing.py +2 -0
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -2
- flwr/simulation/app.py +52 -36
- flwr/simulation/ray_transport/ray_actor.py +15 -19
- flwr/simulation/ray_transport/ray_client_proxy.py +33 -13
- flwr/simulation/run_simulation.py +237 -66
- flwr/superexec/app.py +14 -7
- flwr/superexec/deployment.py +186 -0
- flwr/superexec/exec_grpc.py +5 -1
- flwr/superexec/exec_servicer.py +4 -1
- flwr/superexec/executor.py +18 -0
- flwr/superexec/simulation.py +151 -0
- {flwr_nightly-1.10.0.dev20240624.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/METADATA +3 -2
- {flwr_nightly-1.10.0.dev20240624.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/RECORD +95 -88
- {flwr_nightly-1.10.0.dev20240624.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.10.0.dev20240624.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.10.0.dev20240624.dist-info → flwr_nightly-1.10.0.dev20240722.dist-info}/entry_points.txt +0 -0
flwr/cli/build.py
CHANGED
|
@@ -20,6 +20,7 @@ from pathlib import Path
|
|
|
20
20
|
from typing import Optional
|
|
21
21
|
|
|
22
22
|
import pathspec
|
|
23
|
+
import tomli_w
|
|
23
24
|
import typer
|
|
24
25
|
from typing_extensions import Annotated
|
|
25
26
|
|
|
@@ -31,7 +32,7 @@ from .utils import get_sha256_hash, is_valid_project_name
|
|
|
31
32
|
def build(
|
|
32
33
|
directory: Annotated[
|
|
33
34
|
Optional[Path],
|
|
34
|
-
typer.Option(help="
|
|
35
|
+
typer.Option(help="Path of the Flower project to bundle into a FAB"),
|
|
35
36
|
] = None,
|
|
36
37
|
) -> str:
|
|
37
38
|
"""Build a Flower project into a Flower App Bundle (FAB).
|
|
@@ -85,7 +86,7 @@ def build(
|
|
|
85
86
|
|
|
86
87
|
# Set the name of the zip file
|
|
87
88
|
fab_filename = (
|
|
88
|
-
f"{conf['
|
|
89
|
+
f"{conf['tool']['flwr']['app']['publisher']}"
|
|
89
90
|
f".{directory.name}"
|
|
90
91
|
f".{conf['project']['version'].replace('.', '-')}.fab"
|
|
91
92
|
)
|
|
@@ -93,15 +94,28 @@ def build(
|
|
|
93
94
|
|
|
94
95
|
allowed_extensions = {".py", ".toml", ".md"}
|
|
95
96
|
|
|
97
|
+
# Remove the 'federations' field from 'tool.flwr' if it exists
|
|
98
|
+
if (
|
|
99
|
+
"tool" in conf
|
|
100
|
+
and "flwr" in conf["tool"]
|
|
101
|
+
and "federations" in conf["tool"]["flwr"]
|
|
102
|
+
):
|
|
103
|
+
del conf["tool"]["flwr"]["federations"]
|
|
104
|
+
|
|
105
|
+
toml_contents = tomli_w.dumps(conf)
|
|
106
|
+
|
|
96
107
|
with zipfile.ZipFile(fab_filename, "w", zipfile.ZIP_DEFLATED) as fab_file:
|
|
108
|
+
fab_file.writestr("pyproject.toml", toml_contents)
|
|
109
|
+
|
|
110
|
+
# Continue with adding other files
|
|
97
111
|
for root, _, files in os.walk(directory, topdown=True):
|
|
98
|
-
# Filter directories and files based on .gitignore
|
|
99
112
|
files = [
|
|
100
113
|
f
|
|
101
114
|
for f in files
|
|
102
115
|
if not ignore_spec.match_file(Path(root) / f)
|
|
103
116
|
and f != fab_filename
|
|
104
117
|
and Path(f).suffix in allowed_extensions
|
|
118
|
+
and f != "pyproject.toml" # Exclude the original pyproject.toml
|
|
105
119
|
]
|
|
106
120
|
|
|
107
121
|
for file in files:
|
|
@@ -118,7 +132,7 @@ def build(
|
|
|
118
132
|
fab_file.writestr(".info/CONTENT", list_file_content)
|
|
119
133
|
|
|
120
134
|
typer.secho(
|
|
121
|
-
f"🎊 Successfully built {fab_filename}
|
|
135
|
+
f"🎊 Successfully built {fab_filename}", fg=typer.colors.GREEN, bold=True
|
|
122
136
|
)
|
|
123
137
|
|
|
124
138
|
return fab_filename
|
flwr/cli/config_utils.py
CHANGED
|
@@ -17,11 +17,12 @@
|
|
|
17
17
|
import zipfile
|
|
18
18
|
from io import BytesIO
|
|
19
19
|
from pathlib import Path
|
|
20
|
-
from typing import IO, Any, Dict, List, Optional, Tuple, Union
|
|
20
|
+
from typing import IO, Any, Dict, List, Optional, Tuple, Union, get_args
|
|
21
21
|
|
|
22
22
|
import tomli
|
|
23
23
|
|
|
24
24
|
from flwr.common import object_ref
|
|
25
|
+
from flwr.common.typing import UserConfigValue
|
|
25
26
|
|
|
26
27
|
|
|
27
28
|
def get_fab_metadata(fab_file: Union[Path, bytes]) -> Tuple[str, str]:
|
|
@@ -60,7 +61,7 @@ def get_fab_metadata(fab_file: Union[Path, bytes]) -> Tuple[str, str]:
|
|
|
60
61
|
|
|
61
62
|
return (
|
|
62
63
|
conf["project"]["version"],
|
|
63
|
-
f"{conf['
|
|
64
|
+
f"{conf['tool']['flwr']['app']['publisher']}/{conf['project']['name']}",
|
|
64
65
|
)
|
|
65
66
|
|
|
66
67
|
|
|
@@ -108,6 +109,17 @@ def load(path: Optional[Path] = None) -> Optional[Dict[str, Any]]:
|
|
|
108
109
|
return load_from_string(toml_file.read())
|
|
109
110
|
|
|
110
111
|
|
|
112
|
+
def _validate_run_config(config_dict: Dict[str, Any], errors: List[str]) -> None:
|
|
113
|
+
for key, value in config_dict.items():
|
|
114
|
+
if isinstance(value, dict):
|
|
115
|
+
_validate_run_config(config_dict[key], errors)
|
|
116
|
+
elif not isinstance(value, get_args(UserConfigValue)):
|
|
117
|
+
raise ValueError(
|
|
118
|
+
f"The value for key {key} needs to be of type `int`, `float`, "
|
|
119
|
+
"`bool, `str`, or a `dict` of those.",
|
|
120
|
+
)
|
|
121
|
+
|
|
122
|
+
|
|
111
123
|
# pylint: disable=too-many-branches
|
|
112
124
|
def validate_fields(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]]:
|
|
113
125
|
"""Validate pyproject.toml fields."""
|
|
@@ -128,18 +140,28 @@ def validate_fields(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]]
|
|
|
128
140
|
if "authors" not in config["project"]:
|
|
129
141
|
warnings.append('Recommended property "authors" missing in [project]')
|
|
130
142
|
|
|
131
|
-
if
|
|
132
|
-
|
|
143
|
+
if (
|
|
144
|
+
"tool" not in config
|
|
145
|
+
or "flwr" not in config["tool"]
|
|
146
|
+
or "app" not in config["tool"]["flwr"]
|
|
147
|
+
):
|
|
148
|
+
errors.append("Missing [tool.flwr.app] section")
|
|
133
149
|
else:
|
|
134
|
-
if "publisher" not in config["
|
|
135
|
-
errors.append('Property "publisher" missing in [
|
|
136
|
-
if "
|
|
137
|
-
|
|
150
|
+
if "publisher" not in config["tool"]["flwr"]["app"]:
|
|
151
|
+
errors.append('Property "publisher" missing in [tool.flwr.app]')
|
|
152
|
+
if "config" in config["tool"]["flwr"]["app"]:
|
|
153
|
+
_validate_run_config(config["tool"]["flwr"]["app"]["config"], errors)
|
|
154
|
+
if "components" not in config["tool"]["flwr"]["app"]:
|
|
155
|
+
errors.append("Missing [tool.flwr.app.components] section")
|
|
138
156
|
else:
|
|
139
|
-
if "serverapp" not in config["
|
|
140
|
-
errors.append(
|
|
141
|
-
|
|
142
|
-
|
|
157
|
+
if "serverapp" not in config["tool"]["flwr"]["app"]["components"]:
|
|
158
|
+
errors.append(
|
|
159
|
+
'Property "serverapp" missing in [tool.flwr.app.components]'
|
|
160
|
+
)
|
|
161
|
+
if "clientapp" not in config["tool"]["flwr"]["app"]["components"]:
|
|
162
|
+
errors.append(
|
|
163
|
+
'Property "clientapp" missing in [tool.flwr.app.components]'
|
|
164
|
+
)
|
|
143
165
|
|
|
144
166
|
return len(errors) == 0, errors, warnings
|
|
145
167
|
|
|
@@ -155,14 +177,14 @@ def validate(
|
|
|
155
177
|
|
|
156
178
|
# Validate serverapp
|
|
157
179
|
is_valid, reason = object_ref.validate(
|
|
158
|
-
config["
|
|
180
|
+
config["tool"]["flwr"]["app"]["components"]["serverapp"], check_module
|
|
159
181
|
)
|
|
160
182
|
if not is_valid and isinstance(reason, str):
|
|
161
183
|
return False, [reason], []
|
|
162
184
|
|
|
163
185
|
# Validate clientapp
|
|
164
186
|
is_valid, reason = object_ref.validate(
|
|
165
|
-
config["
|
|
187
|
+
config["tool"]["flwr"]["app"]["components"]["clientapp"], check_module
|
|
166
188
|
)
|
|
167
189
|
|
|
168
190
|
if not is_valid and isinstance(reason, str):
|
flwr/cli/install.py
CHANGED
|
@@ -16,6 +16,7 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import shutil
|
|
19
|
+
import subprocess
|
|
19
20
|
import tempfile
|
|
20
21
|
import zipfile
|
|
21
22
|
from io import BytesIO
|
|
@@ -149,7 +150,7 @@ def validate_and_install(
|
|
|
149
150
|
)
|
|
150
151
|
raise typer.Exit(code=1)
|
|
151
152
|
|
|
152
|
-
publisher = config["
|
|
153
|
+
publisher = config["tool"]["flwr"]["app"]["publisher"]
|
|
153
154
|
project_name = config["project"]["name"]
|
|
154
155
|
version = config["project"]["version"]
|
|
155
156
|
|
|
@@ -192,6 +193,21 @@ def validate_and_install(
|
|
|
192
193
|
else:
|
|
193
194
|
shutil.copy2(item, install_dir / item.name)
|
|
194
195
|
|
|
196
|
+
try:
|
|
197
|
+
subprocess.run(
|
|
198
|
+
["pip", "install", "-e", install_dir, "--no-deps"],
|
|
199
|
+
capture_output=True,
|
|
200
|
+
text=True,
|
|
201
|
+
check=True,
|
|
202
|
+
)
|
|
203
|
+
except subprocess.CalledProcessError as e:
|
|
204
|
+
typer.secho(
|
|
205
|
+
f"❌ Failed to `pip install` package(s) from {install_dir}:\n{e.stderr}",
|
|
206
|
+
fg=typer.colors.RED,
|
|
207
|
+
bold=True,
|
|
208
|
+
)
|
|
209
|
+
raise typer.Exit(code=1) from e
|
|
210
|
+
|
|
195
211
|
typer.secho(
|
|
196
212
|
f"🎊 Successfully installed {project_name} to {install_dir}.",
|
|
197
213
|
fg=typer.colors.GREEN,
|
flwr/cli/new/new.py
CHANGED
|
@@ -14,9 +14,9 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Flower command line interface `new` command."""
|
|
16
16
|
|
|
17
|
-
import os
|
|
18
17
|
import re
|
|
19
18
|
from enum import Enum
|
|
19
|
+
from pathlib import Path
|
|
20
20
|
from string import Template
|
|
21
21
|
from typing import Dict, Optional
|
|
22
22
|
|
|
@@ -59,10 +59,10 @@ class TemplateNotFound(Exception):
|
|
|
59
59
|
|
|
60
60
|
def load_template(name: str) -> str:
|
|
61
61
|
"""Load template from template directory and return as text."""
|
|
62
|
-
tpl_dir =
|
|
63
|
-
tpl_file_path =
|
|
62
|
+
tpl_dir = (Path(__file__).parent / "templates").absolute()
|
|
63
|
+
tpl_file_path = tpl_dir / name
|
|
64
64
|
|
|
65
|
-
if not
|
|
65
|
+
if not tpl_file_path.is_file():
|
|
66
66
|
raise TemplateNotFound(f"Template '{name}' not found")
|
|
67
67
|
|
|
68
68
|
with open(tpl_file_path, encoding="utf-8") as tpl_file:
|
|
@@ -78,14 +78,13 @@ def render_template(template: str, data: Dict[str, str]) -> str:
|
|
|
78
78
|
return tpl.template
|
|
79
79
|
|
|
80
80
|
|
|
81
|
-
def create_file(file_path:
|
|
81
|
+
def create_file(file_path: Path, content: str) -> None:
|
|
82
82
|
"""Create file including all nessecary directories and write content into file."""
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
f.write(content)
|
|
83
|
+
file_path.parent.mkdir(exist_ok=True)
|
|
84
|
+
file_path.write_text(content)
|
|
86
85
|
|
|
87
86
|
|
|
88
|
-
def render_and_create(file_path:
|
|
87
|
+
def render_and_create(file_path: Path, template: str, context: Dict[str, str]) -> None:
|
|
89
88
|
"""Render template and write to file."""
|
|
90
89
|
content = render_template(template, context)
|
|
91
90
|
create_file(file_path, content)
|
|
@@ -117,6 +116,21 @@ def new(
|
|
|
117
116
|
default=sanitize_project_name(project_name),
|
|
118
117
|
)
|
|
119
118
|
|
|
119
|
+
# Set project directory path
|
|
120
|
+
package_name = re.sub(r"[-_.]+", "-", project_name).lower()
|
|
121
|
+
import_name = package_name.replace("-", "_")
|
|
122
|
+
project_dir = Path.cwd() / package_name
|
|
123
|
+
|
|
124
|
+
if project_dir.exists():
|
|
125
|
+
if not typer.confirm(
|
|
126
|
+
typer.style(
|
|
127
|
+
f"\n💬 {project_name} already exists, do you want to override it?",
|
|
128
|
+
fg=typer.colors.MAGENTA,
|
|
129
|
+
bold=True,
|
|
130
|
+
)
|
|
131
|
+
):
|
|
132
|
+
return
|
|
133
|
+
|
|
120
134
|
if username is None:
|
|
121
135
|
username = prompt_text("Please provide your Flower username")
|
|
122
136
|
|
|
@@ -136,6 +150,7 @@ def new(
|
|
|
136
150
|
|
|
137
151
|
framework_str = framework_str.lower()
|
|
138
152
|
|
|
153
|
+
llm_challenge_str = None
|
|
139
154
|
if framework_str == "flowertune":
|
|
140
155
|
llm_challenge_value = prompt_options(
|
|
141
156
|
"Please select LLM challenge by typing in the number",
|
|
@@ -157,12 +172,6 @@ def new(
|
|
|
157
172
|
)
|
|
158
173
|
)
|
|
159
174
|
|
|
160
|
-
# Set project directory path
|
|
161
|
-
cwd = os.getcwd()
|
|
162
|
-
package_name = re.sub(r"[-_.]+", "-", project_name).lower()
|
|
163
|
-
import_name = package_name.replace("-", "_")
|
|
164
|
-
project_dir = os.path.join(cwd, package_name)
|
|
165
|
-
|
|
166
175
|
context = {
|
|
167
176
|
"project_name": project_name,
|
|
168
177
|
"package_name": package_name,
|
|
@@ -171,7 +180,7 @@ def new(
|
|
|
171
180
|
}
|
|
172
181
|
|
|
173
182
|
# List of files to render
|
|
174
|
-
if
|
|
183
|
+
if llm_challenge_str:
|
|
175
184
|
files = {
|
|
176
185
|
".gitignore": {"template": "app/.gitignore.tpl"},
|
|
177
186
|
"pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"},
|
|
@@ -228,10 +237,10 @@ def new(
|
|
|
228
237
|
"README.md": {"template": "app/README.md.tpl"},
|
|
229
238
|
"pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"},
|
|
230
239
|
f"{import_name}/__init__.py": {"template": "app/code/__init__.py.tpl"},
|
|
231
|
-
f"{import_name}/
|
|
240
|
+
f"{import_name}/server_app.py": {
|
|
232
241
|
"template": f"app/code/server.{framework_str}.py.tpl"
|
|
233
242
|
},
|
|
234
|
-
f"{import_name}/
|
|
243
|
+
f"{import_name}/client_app.py": {
|
|
235
244
|
"template": f"app/code/client.{framework_str}.py.tpl"
|
|
236
245
|
},
|
|
237
246
|
}
|
|
@@ -251,7 +260,7 @@ def new(
|
|
|
251
260
|
|
|
252
261
|
for file_path, value in files.items():
|
|
253
262
|
render_and_create(
|
|
254
|
-
file_path=
|
|
263
|
+
file_path=project_dir / file_path,
|
|
255
264
|
template=value["template"],
|
|
256
265
|
context=context,
|
|
257
266
|
)
|
|
@@ -264,9 +273,11 @@ def new(
|
|
|
264
273
|
bold=True,
|
|
265
274
|
)
|
|
266
275
|
)
|
|
276
|
+
|
|
277
|
+
_add = " huggingface-cli login\n" if framework_str == "flowertune" else ""
|
|
267
278
|
print(
|
|
268
279
|
typer.style(
|
|
269
|
-
f" cd {package_name}\n" + " pip install -e .\n flwr run\n",
|
|
280
|
+
f" cd {package_name}\n" + " pip install -e .\n" + _add + " flwr run\n",
|
|
270
281
|
fg=typer.colors.BRIGHT_CYAN,
|
|
271
282
|
bold=True,
|
|
272
283
|
)
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"""$project_name: A Flower / HuggingFace Transformers app."""
|
|
2
2
|
|
|
3
3
|
from flwr.client import ClientApp, NumPyClient
|
|
4
|
+
from flwr.common import Context
|
|
4
5
|
from transformers import AutoModelForSequenceClassification
|
|
5
6
|
|
|
6
7
|
from $import_name.task import (
|
|
@@ -29,7 +30,11 @@ class FlowerClient(NumPyClient):
|
|
|
29
30
|
|
|
30
31
|
def fit(self, parameters, config):
|
|
31
32
|
self.set_parameters(parameters)
|
|
32
|
-
train(
|
|
33
|
+
train(
|
|
34
|
+
self.net,
|
|
35
|
+
self.trainloader,
|
|
36
|
+
epochs=int(self.context.run_config["local-epochs"]),
|
|
37
|
+
)
|
|
33
38
|
return self.get_parameters(config={}), len(self.trainloader), {}
|
|
34
39
|
|
|
35
40
|
def evaluate(self, parameters, config):
|
|
@@ -38,12 +43,15 @@ class FlowerClient(NumPyClient):
|
|
|
38
43
|
return float(loss), len(self.testloader), {"accuracy": accuracy}
|
|
39
44
|
|
|
40
45
|
|
|
41
|
-
def client_fn(
|
|
46
|
+
def client_fn(context: Context):
|
|
42
47
|
# Load model and data
|
|
43
48
|
net = AutoModelForSequenceClassification.from_pretrained(
|
|
44
49
|
CHECKPOINT, num_labels=2
|
|
45
50
|
).to(DEVICE)
|
|
46
|
-
|
|
51
|
+
|
|
52
|
+
partition_id = int(context.node_config["partition-id"])
|
|
53
|
+
num_partitions = int(context.node_config["num-partitions"])
|
|
54
|
+
trainloader, valloader = load_data(partition_id, num_partitions)
|
|
47
55
|
|
|
48
56
|
# Return Client instance
|
|
49
57
|
return FlowerClient(net, trainloader, valloader).to_client()
|
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
import jax
|
|
4
4
|
from flwr.client import NumPyClient, ClientApp
|
|
5
|
+
from flwr.common import Context
|
|
5
6
|
|
|
6
7
|
from $import_name.task import (
|
|
7
8
|
evaluation,
|
|
@@ -44,7 +45,7 @@ class FlowerClient(NumPyClient):
|
|
|
44
45
|
)
|
|
45
46
|
return float(loss), num_examples, {"loss": float(loss)}
|
|
46
47
|
|
|
47
|
-
def client_fn(
|
|
48
|
+
def client_fn(context: Context):
|
|
48
49
|
# Return Client instance
|
|
49
50
|
return FlowerClient().to_client()
|
|
50
51
|
|
|
@@ -4,6 +4,7 @@ import mlx.core as mx
|
|
|
4
4
|
import mlx.nn as nn
|
|
5
5
|
import mlx.optimizers as optim
|
|
6
6
|
from flwr.client import NumPyClient, ClientApp
|
|
7
|
+
from flwr.common import Context
|
|
7
8
|
|
|
8
9
|
from $import_name.task import (
|
|
9
10
|
batch_iterate,
|
|
@@ -19,17 +20,19 @@ from $import_name.task import (
|
|
|
19
20
|
# Define Flower Client and client_fn
|
|
20
21
|
class FlowerClient(NumPyClient):
|
|
21
22
|
def __init__(self, data):
|
|
22
|
-
num_layers =
|
|
23
|
-
hidden_dim =
|
|
23
|
+
num_layers = int(self.context.run_config["num-layers"])
|
|
24
|
+
hidden_dim = int(self.context.run_config["hidden-dim"])
|
|
24
25
|
num_classes = 10
|
|
25
|
-
batch_size =
|
|
26
|
-
|
|
27
|
-
|
|
26
|
+
batch_size = int(self.context.run_config["batch-size"])
|
|
27
|
+
learning_rate = float(self.context.run_config["lr"])
|
|
28
|
+
num_epochs = int(self.context.run_config["local-epochs"])
|
|
28
29
|
|
|
29
30
|
self.train_images, self.train_labels, self.test_images, self.test_labels = data
|
|
30
|
-
self.model = MLP(
|
|
31
|
-
|
|
32
|
-
|
|
31
|
+
self.model = MLP(
|
|
32
|
+
num_layers, self.train_images.shape[-1], hidden_dim, num_classes
|
|
33
|
+
)
|
|
34
|
+
self.optimizer = optim.SGD(learning_rate=learning_rate)
|
|
35
|
+
self.loss_and_grad_fn = nn.value_and_grad(self.model, loss_fn)
|
|
33
36
|
self.num_epochs = num_epochs
|
|
34
37
|
self.batch_size = batch_size
|
|
35
38
|
|
|
@@ -57,8 +60,10 @@ class FlowerClient(NumPyClient):
|
|
|
57
60
|
return loss.item(), len(self.test_images), {"accuracy": accuracy.item()}
|
|
58
61
|
|
|
59
62
|
|
|
60
|
-
def client_fn(
|
|
61
|
-
|
|
63
|
+
def client_fn(context: Context):
|
|
64
|
+
partition_id = int(context.node_config["partition-id"])
|
|
65
|
+
num_partitions = int(context.node_config["num-partitions"])
|
|
66
|
+
data = load_data(partition_id, num_partitions)
|
|
62
67
|
|
|
63
68
|
# Return Client instance
|
|
64
69
|
return FlowerClient(data).to_client()
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"""$project_name: A Flower / NumPy app."""
|
|
2
2
|
|
|
3
3
|
from flwr.client import NumPyClient, ClientApp
|
|
4
|
+
from flwr.common import Context
|
|
4
5
|
import numpy as np
|
|
5
6
|
|
|
6
7
|
|
|
@@ -15,7 +16,7 @@ class FlowerClient(NumPyClient):
|
|
|
15
16
|
return float(0.0), 1, {"accuracy": float(1.0)}
|
|
16
17
|
|
|
17
18
|
|
|
18
|
-
def client_fn(
|
|
19
|
+
def client_fn(context: Context):
|
|
19
20
|
return FlowerClient().to_client()
|
|
20
21
|
|
|
21
22
|
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"""$project_name: A Flower / PyTorch app."""
|
|
2
2
|
|
|
3
3
|
from flwr.client import NumPyClient, ClientApp
|
|
4
|
+
from flwr.common import Context
|
|
4
5
|
|
|
5
6
|
from $import_name.task import (
|
|
6
7
|
Net,
|
|
@@ -22,7 +23,13 @@ class FlowerClient(NumPyClient):
|
|
|
22
23
|
|
|
23
24
|
def fit(self, parameters, config):
|
|
24
25
|
set_weights(self.net, parameters)
|
|
25
|
-
results = train(
|
|
26
|
+
results = train(
|
|
27
|
+
self.net,
|
|
28
|
+
self.trainloader,
|
|
29
|
+
self.valloader,
|
|
30
|
+
int(self.context.run_config["local-epochs"]),
|
|
31
|
+
DEVICE,
|
|
32
|
+
)
|
|
26
33
|
return get_weights(self.net), len(self.trainloader.dataset), results
|
|
27
34
|
|
|
28
35
|
def evaluate(self, parameters, config):
|
|
@@ -31,10 +38,12 @@ class FlowerClient(NumPyClient):
|
|
|
31
38
|
return loss, len(self.valloader.dataset), {"accuracy": accuracy}
|
|
32
39
|
|
|
33
40
|
|
|
34
|
-
def client_fn(
|
|
41
|
+
def client_fn(context: Context):
|
|
35
42
|
# Load model and data
|
|
36
43
|
net = Net().to(DEVICE)
|
|
37
|
-
|
|
44
|
+
partition_id = int(context.node_config["partition-id"])
|
|
45
|
+
num_partitions = int(context.node_config["num-partitions"])
|
|
46
|
+
trainloader, valloader = load_data(partition_id, num_partitions)
|
|
38
47
|
|
|
39
48
|
# Return Client instance
|
|
40
49
|
return FlowerClient(net, trainloader, valloader).to_client()
|
|
@@ -4,6 +4,7 @@ import warnings
|
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
from flwr.client import NumPyClient, ClientApp
|
|
7
|
+
from flwr.common import Context
|
|
7
8
|
from flwr_datasets import FederatedDataset
|
|
8
9
|
from sklearn.linear_model import LogisticRegression
|
|
9
10
|
from sklearn.metrics import log_loss
|
|
@@ -66,10 +67,12 @@ class FlowerClient(NumPyClient):
|
|
|
66
67
|
|
|
67
68
|
return loss, len(self.X_test), {"accuracy": accuracy}
|
|
68
69
|
|
|
69
|
-
fds = FederatedDataset(dataset="mnist", partitioners={"train": 2})
|
|
70
70
|
|
|
71
|
-
def client_fn(
|
|
72
|
-
|
|
71
|
+
def client_fn(context: Context):
|
|
72
|
+
partition_id = int(context.node_config["partition-id"])
|
|
73
|
+
num_partitions = int(context.node_config["num-partitions"])
|
|
74
|
+
fds = FederatedDataset(dataset="mnist", partitioners={"train": num_partitions})
|
|
75
|
+
dataset = fds.load_partition(partition_id, "train").with_format("numpy")
|
|
73
76
|
|
|
74
77
|
X, y = dataset["image"].reshape((len(dataset), -1)), dataset["label"]
|
|
75
78
|
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"""$project_name: A Flower / TensorFlow app."""
|
|
2
2
|
|
|
3
3
|
from flwr.client import NumPyClient, ClientApp
|
|
4
|
+
from flwr.common import Context
|
|
4
5
|
|
|
5
6
|
from $import_name.task import load_data, load_model
|
|
6
7
|
|
|
@@ -19,7 +20,13 @@ class FlowerClient(NumPyClient):
|
|
|
19
20
|
|
|
20
21
|
def fit(self, parameters, config):
|
|
21
22
|
self.model.set_weights(parameters)
|
|
22
|
-
self.model.fit(
|
|
23
|
+
self.model.fit(
|
|
24
|
+
self.x_train,
|
|
25
|
+
self.y_train,
|
|
26
|
+
epochs=int(self.context.run_config["local-epochs"]),
|
|
27
|
+
batch_size=int(self.context.run_config["batch-size"]),
|
|
28
|
+
verbose=bool(self.context.run_config.get("verbose")),
|
|
29
|
+
)
|
|
23
30
|
return self.model.get_weights(), len(self.x_train), {}
|
|
24
31
|
|
|
25
32
|
def evaluate(self, parameters, config):
|
|
@@ -28,10 +35,13 @@ class FlowerClient(NumPyClient):
|
|
|
28
35
|
return loss, len(self.x_test), {"accuracy": accuracy}
|
|
29
36
|
|
|
30
37
|
|
|
31
|
-
def client_fn(
|
|
38
|
+
def client_fn(context: Context):
|
|
32
39
|
# Load model and data
|
|
33
40
|
net = load_model()
|
|
34
|
-
|
|
41
|
+
|
|
42
|
+
partition_id = int(context.node_config["partition-id"])
|
|
43
|
+
num_partitions = int(context.node_config["num-partitions"])
|
|
44
|
+
x_train, y_train, x_test, y_test = load_data(partition_id, num_partitions)
|
|
35
45
|
|
|
36
46
|
# Return Client instance
|
|
37
47
|
return FlowerClient(net, x_train, y_train, x_test, y_test).to_client()
|
|
@@ -12,10 +12,10 @@ from flwr.client import ClientApp
|
|
|
12
12
|
from flwr.common import ndarrays_to_parameters
|
|
13
13
|
from flwr.server import ServerApp, ServerConfig
|
|
14
14
|
|
|
15
|
-
from $import_name.
|
|
15
|
+
from $import_name.client_app import gen_client_fn, get_parameters
|
|
16
16
|
from $import_name.dataset import get_tokenizer_and_data_collator_and_propt_formatting
|
|
17
17
|
from $import_name.models import get_model
|
|
18
|
-
from $import_name.
|
|
18
|
+
from $import_name.server_app import fit_weighted_average, get_evaluate_fn, get_on_fit_config
|
|
19
19
|
|
|
20
20
|
# Avoid warnings
|
|
21
21
|
warnings.filterwarnings("ignore", category=UserWarning)
|
|
@@ -1,17 +1,22 @@
|
|
|
1
1
|
"""$project_name: A Flower / HuggingFace Transformers app."""
|
|
2
2
|
|
|
3
|
+
from flwr.common import Context
|
|
3
4
|
from flwr.server.strategy import FedAvg
|
|
4
|
-
from flwr.server import ServerApp, ServerConfig
|
|
5
|
+
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
|
|
5
6
|
|
|
6
7
|
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
fraction_evaluate=1.0,
|
|
11
|
-
)
|
|
8
|
+
def server_fn(context: Context):
|
|
9
|
+
# Read from config
|
|
10
|
+
num_rounds = int(context.run_config["num-server-rounds"])
|
|
12
11
|
|
|
13
|
-
#
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
)
|
|
12
|
+
# Define strategy
|
|
13
|
+
strategy = FedAvg(
|
|
14
|
+
fraction_fit=1.0,
|
|
15
|
+
fraction_evaluate=1.0,
|
|
16
|
+
)
|
|
17
|
+
config = ServerConfig(num_rounds=num_rounds)
|
|
18
|
+
|
|
19
|
+
return ServerAppComponents(strategy=strategy, config=config)
|
|
20
|
+
|
|
21
|
+
# Create ServerApp
|
|
22
|
+
app = ServerApp(server_fn=server_fn)
|
|
@@ -1,12 +1,19 @@
|
|
|
1
1
|
"""$project_name: A Flower / JAX app."""
|
|
2
2
|
|
|
3
|
-
|
|
3
|
+
from flwr.common import Context
|
|
4
|
+
from flwr.server.strategy import FedAvg
|
|
5
|
+
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
|
|
4
6
|
|
|
5
|
-
# Configure the strategy
|
|
6
|
-
strategy = fl.server.strategy.FedAvg()
|
|
7
7
|
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
8
|
+
def server_fn(context: Context):
|
|
9
|
+
# Read from config
|
|
10
|
+
num_rounds = int(context.run_config["num-server-rounds"])
|
|
11
|
+
|
|
12
|
+
# Define strategy
|
|
13
|
+
strategy = FedAvg()
|
|
14
|
+
config = ServerConfig(num_rounds=num_rounds)
|
|
15
|
+
|
|
16
|
+
return ServerAppComponents(strategy=strategy, config=config)
|
|
17
|
+
|
|
18
|
+
# Create ServerApp
|
|
19
|
+
app = ServerApp(server_fn=server_fn)
|
|
@@ -1,15 +1,19 @@
|
|
|
1
1
|
"""$project_name: A Flower / MLX app."""
|
|
2
2
|
|
|
3
|
-
from flwr.
|
|
3
|
+
from flwr.common import Context
|
|
4
|
+
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
|
|
4
5
|
from flwr.server.strategy import FedAvg
|
|
5
6
|
|
|
6
7
|
|
|
7
|
-
|
|
8
|
-
|
|
8
|
+
def server_fn(context: Context):
|
|
9
|
+
# Read from config
|
|
10
|
+
num_rounds = int(context.run_config["num-server-rounds"])
|
|
9
11
|
|
|
12
|
+
# Define strategy
|
|
13
|
+
strategy = FedAvg()
|
|
14
|
+
config = ServerConfig(num_rounds=num_rounds)
|
|
15
|
+
|
|
16
|
+
return ServerAppComponents(strategy=strategy, config=config)
|
|
10
17
|
|
|
11
18
|
# Create ServerApp
|
|
12
|
-
app = ServerApp(
|
|
13
|
-
config=ServerConfig(num_rounds=3),
|
|
14
|
-
strategy=strategy,
|
|
15
|
-
)
|
|
19
|
+
app = ServerApp(server_fn=server_fn)
|