flwr-nightly 1.12.0.dev20240907__py3-none-any.whl → 1.12.0.dev20240914__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic.
- flwr/cli/build.py +1 -2
- flwr/cli/config_utils.py +10 -10
- flwr/cli/install.py +1 -2
- flwr/cli/new/new.py +26 -40
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +19 -29
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +0 -3
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +18 -3
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -13
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +9 -1
- flwr/cli/run/run.py +6 -7
- flwr/cli/utils.py +2 -2
- flwr/client/app.py +14 -14
- flwr/client/client_app.py +5 -5
- flwr/client/clientapp/app.py +2 -2
- flwr/client/dpfedavg_numpy_client.py +6 -7
- flwr/client/grpc_adapter_client/connection.py +4 -3
- flwr/client/grpc_client/connection.py +4 -3
- flwr/client/grpc_rere_client/client_interceptor.py +5 -5
- flwr/client/grpc_rere_client/connection.py +5 -4
- flwr/client/grpc_rere_client/grpc_adapter.py +2 -2
- flwr/client/message_handler/message_handler.py +3 -3
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +25 -25
- flwr/client/mod/utils.py +1 -3
- flwr/client/node_state.py +2 -2
- flwr/client/numpy_client.py +8 -8
- flwr/client/rest_client/connection.py +5 -4
- flwr/client/supernode/app.py +7 -8
- flwr/common/address.py +2 -2
- flwr/common/config.py +8 -8
- flwr/common/constant.py +12 -1
- flwr/common/differential_privacy.py +2 -2
- flwr/common/dp.py +1 -3
- flwr/common/exit_handlers.py +3 -3
- flwr/common/grpc.py +2 -1
- flwr/common/logger.py +3 -3
- flwr/common/object_ref.py +3 -3
- flwr/common/record/configsrecord.py +3 -3
- flwr/common/record/metricsrecord.py +3 -3
- flwr/common/record/parametersrecord.py +3 -2
- flwr/common/record/recordset.py +1 -1
- flwr/common/record/typeddict.py +23 -10
- flwr/common/recordset_compat.py +7 -5
- flwr/common/retry_invoker.py +6 -17
- flwr/common/secure_aggregation/crypto/shamir.py +10 -10
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +2 -2
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +16 -16
- flwr/common/secure_aggregation/quantization.py +7 -7
- flwr/common/secure_aggregation/secaggplus_utils.py +3 -5
- flwr/common/serde.py +11 -9
- flwr/common/telemetry.py +5 -5
- flwr/common/typing.py +19 -19
- flwr/common/version.py +2 -3
- flwr/server/app.py +18 -18
- flwr/server/client_manager.py +6 -6
- flwr/server/compat/app_utils.py +2 -3
- flwr/server/driver/driver.py +3 -2
- flwr/server/driver/grpc_driver.py +7 -7
- flwr/server/driver/inmemory_driver.py +5 -4
- flwr/server/history.py +8 -9
- flwr/server/run_serverapp.py +5 -6
- flwr/server/server.py +36 -36
- flwr/server/strategy/aggregate.py +13 -13
- flwr/server/strategy/bulyan.py +8 -8
- flwr/server/strategy/dp_adaptive_clipping.py +20 -20
- flwr/server/strategy/dp_fixed_clipping.py +19 -19
- flwr/server/strategy/dpfedavg_adaptive.py +6 -6
- flwr/server/strategy/dpfedavg_fixed.py +10 -10
- flwr/server/strategy/fault_tolerant_fedavg.py +11 -11
- flwr/server/strategy/fedadagrad.py +8 -8
- flwr/server/strategy/fedadam.py +8 -8
- flwr/server/strategy/fedavg.py +16 -16
- flwr/server/strategy/fedavg_android.py +16 -16
- flwr/server/strategy/fedavgm.py +8 -8
- flwr/server/strategy/fedmedian.py +4 -4
- flwr/server/strategy/fedopt.py +5 -5
- flwr/server/strategy/fedprox.py +6 -6
- flwr/server/strategy/fedtrimmedavg.py +8 -8
- flwr/server/strategy/fedxgb_bagging.py +11 -11
- flwr/server/strategy/fedxgb_cyclic.py +9 -9
- flwr/server/strategy/fedxgb_nn_avg.py +5 -5
- flwr/server/strategy/fedyogi.py +8 -8
- flwr/server/strategy/krum.py +8 -8
- flwr/server/strategy/qfedavg.py +15 -15
- flwr/server/strategy/strategy.py +10 -10
- flwr/server/superlink/driver/driver_grpc.py +2 -2
- flwr/server/superlink/driver/driver_servicer.py +6 -6
- flwr/server/superlink/ffs/disk_ffs.py +4 -4
- flwr/server/superlink/ffs/ffs.py +4 -4
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +2 -2
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +2 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +2 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +9 -8
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +5 -3
- flwr/server/superlink/fleet/message_handler/message_handler.py +2 -2
- flwr/server/superlink/fleet/rest_rere/rest_api.py +2 -1
- flwr/server/superlink/fleet/vce/backend/__init__.py +2 -3
- flwr/server/superlink/fleet/vce/backend/backend.py +3 -3
- flwr/server/superlink/fleet/vce/backend/raybackend.py +26 -17
- flwr/server/superlink/fleet/vce/vce_api.py +6 -6
- flwr/server/superlink/state/in_memory_state.py +18 -18
- flwr/server/superlink/state/sqlite_state.py +22 -21
- flwr/server/superlink/state/state.py +7 -7
- flwr/server/utils/tensorboard.py +4 -4
- flwr/server/utils/validator.py +2 -2
- flwr/server/workflow/default_workflows.py +5 -5
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +22 -22
- flwr/simulation/app.py +8 -8
- flwr/simulation/ray_transport/ray_actor.py +23 -23
- flwr/simulation/run_simulation.py +16 -4
- flwr/superexec/app.py +4 -4
- flwr/superexec/deployment.py +2 -2
- flwr/superexec/exec_grpc.py +2 -2
- flwr/superexec/exec_servicer.py +3 -2
- {flwr_nightly-1.12.0.dev20240907.dist-info → flwr_nightly-1.12.0.dev20240914.dist-info}/METADATA +4 -6
- {flwr_nightly-1.12.0.dev20240907.dist-info → flwr_nightly-1.12.0.dev20240914.dist-info}/RECORD +118 -118
- {flwr_nightly-1.12.0.dev20240907.dist-info → flwr_nightly-1.12.0.dev20240914.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.12.0.dev20240907.dist-info → flwr_nightly-1.12.0.dev20240914.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.12.0.dev20240907.dist-info → flwr_nightly-1.12.0.dev20240914.dist-info}/entry_points.txt +0 -0
flwr/cli/build.py
CHANGED
@@ -17,12 +17,11 @@
 import os
 import zipfile
 from pathlib import Path
-from typing import Optional
+from typing import Annotated, Optional
 
 import pathspec
 import tomli_w
 import typer
-from typing_extensions import Annotated
 
 from .config_utils import load_and_validate
 from .utils import get_sha256_hash, is_valid_project_name

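Across the CLI modules above (and several below), the diff replaces `typing_extensions.Annotated` with `typing.Annotated` and the capitalized `Dict`/`List`/`Tuple` aliases with built-in generics, both of which require Python 3.9 or newer. A minimal sketch of the resulting style, using a hypothetical `greet` command rather than any real flwr subcommand:

    # Sketch only: illustrates the import style the diff moves to.
    # Annotated now comes from typing; built-in generics replace Dict/List/Tuple.
    from typing import Annotated

    import typer

    app = typer.Typer()


    @app.command()
    def greet(
        name: Annotated[str, typer.Argument(help="Who to greet.")],
        shout: Annotated[bool, typer.Option(help="Uppercase the greeting.")] = False,
    ) -> None:
        """Print a greeting."""
        message = f"Hello, {name}!"
        typer.echo(message.upper() if shout else message)


    if __name__ == "__main__":
        app()
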
flwr/cli/config_utils.py
CHANGED
@@ -17,7 +17,7 @@
 import zipfile
 from io import BytesIO
 from pathlib import Path
-from typing import IO, Any, Dict, List, Optional, Tuple, Union, get_args
+from typing import IO, Any, Optional, Union, get_args
 
 import tomli
 
@@ -25,7 +25,7 @@ from flwr.common import object_ref
 from flwr.common.typing import UserConfigValue
 
 
-def get_fab_config(fab_file: Union[Path, bytes]) -> Dict[str, Any]:
+def get_fab_config(fab_file: Union[Path, bytes]) -> dict[str, Any]:
     """Extract the config from a FAB file or path.
 
     Parameters
@@ -62,7 +62,7 @@ def get_fab_config(fab_file: Union[Path, bytes]) -> Dict[str, Any]:
     return conf
 
 
-def get_fab_metadata(fab_file: Union[Path, bytes]) -> Tuple[str, str]:
+def get_fab_metadata(fab_file: Union[Path, bytes]) -> tuple[str, str]:
     """Extract the fab_id and the fab_version from a FAB file or path.
 
     Parameters
@@ -87,7 +87,7 @@ def get_fab_metadata(fab_file: Union[Path, bytes]) -> Tuple[str, str]:
 def load_and_validate(
     path: Optional[Path] = None,
     check_module: bool = True,
-) -> Tuple[Optional[Dict[str, Any]], List[str], List[str]]:
+) -> tuple[Optional[dict[str, Any]], list[str], list[str]]:
     """Load and validate pyproject.toml as dict.
 
     Returns
@@ -116,7 +116,7 @@ def load_and_validate(
     return (config, errors, warnings)
 
 
-def load(toml_path: Path) -> Optional[Dict[str, Any]]:
+def load(toml_path: Path) -> Optional[dict[str, Any]]:
     """Load pyproject.toml and return as dict."""
     if not toml_path.is_file():
         return None
@@ -125,7 +125,7 @@ def load(toml_path: Path) -> Optional[Dict[str, Any]]:
         return load_from_string(toml_file.read())
 
 
-def _validate_run_config(config_dict: Dict[str, Any], errors: List[str]) -> None:
+def _validate_run_config(config_dict: dict[str, Any], errors: list[str]) -> None:
     for key, value in config_dict.items():
         if isinstance(value, dict):
             _validate_run_config(config_dict[key], errors)
@@ -137,7 +137,7 @@ def _validate_run_config(config_dict: Dict[str, Any], errors: List[str]) -> None
 
 
 # pylint: disable=too-many-branches
-def validate_fields(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]]:
+def validate_fields(config: dict[str, Any]) -> tuple[bool, list[str], list[str]]:
     """Validate pyproject.toml fields."""
     errors = []
     warnings = []
@@ -183,10 +183,10 @@ def validate_fields(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]]
 
 
 def validate(
-    config: Dict[str, Any],
+    config: dict[str, Any],
     check_module: bool = True,
     project_dir: Optional[Union[str, Path]] = None,
-) -> Tuple[bool, List[str], List[str]]:
+) -> tuple[bool, list[str], list[str]]:
     """Validate pyproject.toml."""
     is_valid, errors, warnings = validate_fields(config)
 
@@ -210,7 +210,7 @@ def validate(
     return True, [], []
 
 
-def load_from_string(toml_content: str) -> Optional[Dict[str, Any]]:
+def load_from_string(toml_content: str) -> Optional[dict[str, Any]]:
     """Load TOML content from a string and return as dict."""
     try:
         data = tomli.loads(toml_content)

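For context, `load_and_validate` above is the entry point the CLI uses to read a project's `pyproject.toml`; with the new signature it returns a plain `dict` plus error and warning lists. A small usage sketch with a made-up project path:

    from pathlib import Path

    from flwr.cli.config_utils import load_and_validate

    # Hypothetical project path; load_and_validate returns
    # (config | None, errors, warnings) as shown in the diff above.
    config, errors, warnings = load_and_validate(Path("my-flower-app") / "pyproject.toml")

    if config is None:
        print("pyproject.toml is invalid:")
        for err in errors:
            print(f"  - {err}")
    else:
        print(f"loaded config with {len(warnings)} warning(s)")
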
flwr/cli/install.py
CHANGED
@@ -21,10 +21,9 @@ import tempfile
 import zipfile
 from io import BytesIO
 from pathlib import Path
-from typing import IO, Optional, Union
+from typing import IO, Annotated, Optional, Union
 
 import typer
-from typing_extensions import Annotated
 
 from flwr.common.config import get_flwr_dir
 

flwr/cli/new/new.py
CHANGED
@@ -18,10 +18,9 @@ import re
 from enum import Enum
 from pathlib import Path
 from string import Template
-from typing import Dict, Optional
+from typing import Annotated, Optional
 
 import typer
-from typing_extensions import Annotated
 
 from ..utils import (
     is_valid_project_name,
@@ -70,7 +69,7 @@ def load_template(name: str) -> str:
     return tpl_file.read()
 
 
-def render_template(template: str, data: Dict[str, str]) -> str:
+def render_template(template: str, data: dict[str, str]) -> str:
     """Render template."""
     tpl_file = load_template(template)
     tpl = Template(tpl_file)
@@ -85,7 +84,7 @@ def create_file(file_path: Path, content: str) -> None:
     file_path.write_text(content)
 
 
-def render_and_create(file_path: Path, template: str, context: Dict[str, str]) -> None:
+def render_and_create(file_path: Path, template: str, context: dict[str, str]) -> None:
     """Render template and write to file."""
     content = render_template(template, context)
     create_file(file_path, content)
@@ -136,36 +135,23 @@ def new(
     username = prompt_text("Please provide your Flower username")
 
     if framework is not None:
-        framework_value = str(framework.value)
+        framework_str = str(framework.value)
     else:
-        framework_value = prompt_options(
+        framework_str = prompt_options(
             "Please select ML framework by typing in the number",
             [mlf.value for mlf in MlFramework],
         )
-    selected_value = [
-        name
-        for name, value in vars(MlFramework).items()
-        if value == framework_value
-    ]
-    framework_str_upper = selected_value[0]
-
-    framework_str = framework_str_upper.lower()
 
     llm_challenge_str = None
-    if framework_str == "flowertune":
+    if framework_str == MlFramework.FLOWERTUNE:
         llm_challenge_value = prompt_options(
             "Please select LLM challenge by typing in the number",
             sorted([challenge.value for challenge in LlmChallengeName]),
        )
-        selected_value = [
-            name
-            for name, value in vars(LlmChallengeName).items()
-            if value == llm_challenge_value
-        ]
-        llm_challenge_str = selected_value[0]
-        llm_challenge_str = llm_challenge_str.lower()
+        llm_challenge_str = llm_challenge_value.lower()
 
-    is_baseline_project = framework_str == "baseline"
+    if framework_str == MlFramework.BASELINE:
+        framework_str = "baseline"
 
     print(
         typer.style(
@@ -176,19 +162,21 @@ def new(
     )
 
     context = {
-        "framework_str": framework_str_upper,
+        "framework_str": framework_str,
         "import_name": import_name.replace("-", "_"),
         "package_name": package_name,
         "project_name": app_name,
         "username": username,
     }
 
+    template_name = framework_str.lower()
+
     # List of files to render
     if llm_challenge_str:
         files = {
             ".gitignore": {"template": "app/.gitignore.tpl"},
-            "pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"},
-            "README.md": {"template": f"app/README.{framework_str}.md.tpl"},
+            "pyproject.toml": {"template": f"app/pyproject.{template_name}.toml.tpl"},
+            "README.md": {"template": f"app/README.{template_name}.md.tpl"},
             f"{import_name}/__init__.py": {"template": "app/code/__init__.py.tpl"},
             f"{import_name}/server_app.py": {
                 "template": "app/code/flwr_tune/server_app.py.tpl"
@@ -235,44 +223,42 @@ def new(
         files = {
             ".gitignore": {"template": "app/.gitignore.tpl"},
             "README.md": {"template": "app/README.md.tpl"},
-            "pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"},
+            "pyproject.toml": {"template": f"app/pyproject.{template_name}.toml.tpl"},
             f"{import_name}/__init__.py": {"template": "app/code/__init__.py.tpl"},
             f"{import_name}/server_app.py": {
-                "template": f"app/code/server.{framework_str}.py.tpl"
+                "template": f"app/code/server.{template_name}.py.tpl"
             },
             f"{import_name}/client_app.py": {
-                "template": f"app/code/client.{framework_str}.py.tpl"
+                "template": f"app/code/client.{template_name}.py.tpl"
             },
         }
 
         # Depending on the framework, generate task.py file
         frameworks_with_tasks = [
-            MlFramework.PYTORCH.value.lower(),
-            MlFramework.JAX.value.lower(),
-            MlFramework.HUGGINGFACE.value.lower(),
-            MlFramework.MLX.value.lower(),
-            MlFramework.TENSORFLOW.value.lower(),
+            MlFramework.PYTORCH.value,
+            MlFramework.JAX.value,
+            MlFramework.HUGGINGFACE.value,
+            MlFramework.MLX.value,
+            MlFramework.TENSORFLOW.value,
         ]
         if framework_str in frameworks_with_tasks:
             files[f"{import_name}/task.py"] = {
-                "template": f"app/code/task.{framework_str}.py.tpl"
+                "template": f"app/code/task.{template_name}.py.tpl"
            }
 
-        if is_baseline_project:
+        if framework_str == "baseline":
            # Include additional files for baseline template
            for file_name in ["model", "dataset", "strategy", "utils", "__init__"]:
                files[f"{import_name}/{file_name}.py"] = {
-                    "template": f"app/code/{file_name}.{framework_str}.py.tpl"
+                    "template": f"app/code/{file_name}.{template_name}.py.tpl"
                }
 
            # Replace README.md
-            files["README.md"]["template"] = f"app/README.{framework_str}.md.tpl"
+            files["README.md"]["template"] = f"app/README.{template_name}.md.tpl"
 
            # Add LICENSE
            files["LICENSE"] = {"template": "app/LICENSE.tpl"}
 
-            context["framework_str"] = "baseline"
-
     for file_path, value in files.items():
         render_and_create(
             file_path=project_dir / file_path,

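As a quick illustration of the refactor above: the selected framework's display value is now kept in `framework_str` for rendering, while a lowercased `template_name` selects the template files. A standalone sketch (the framework value is only an example):

    # Standalone sketch of the path derivation used in `new` above.
    framework_str = "PyTorch"  # e.g. the value carried by the framework enum
    template_name = framework_str.lower()

    files = {
        "pyproject.toml": {"template": f"app/pyproject.{template_name}.toml.tpl"},
        "README.md": {"template": f"app/README.{template_name}.md.tpl"},
    }

    assert files["pyproject.toml"]["template"] == "app/pyproject.pytorch.toml.tpl"
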
flwr/cli/new/templates/app/code/client.huggingface.py.tpl
CHANGED
@@ -1,18 +1,11 @@
 """$project_name: A Flower / $framework_str app."""
 
+import torch
 from flwr.client import ClientApp, NumPyClient
 from flwr.common import Context
 from transformers import AutoModelForSequenceClassification
 
-from $import_name.task import (
-    get_weights,
-    load_data,
-    set_weights,
-    train,
-    test,
-    CHECKPOINT,
-    DEVICE,
-)
+from $import_name.task import get_weights, load_data, set_weights, test, train
 
 
 # Flower client
@@ -22,37 +15,34 @@ class FlowerClient(NumPyClient):
         self.trainloader = trainloader
         self.testloader = testloader
         self.local_epochs = local_epochs
-
-    def get_parameters(self, config):
-        return get_weights(self.net)
-
-    def set_parameters(self, parameters):
-        set_weights(self.net, parameters)
+        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        self.net.to(self.device)
 
     def fit(self, parameters, config):
-        self.set_parameters(parameters)
-        train(
-            self.net,
-            self.trainloader,
-            epochs=self.local_epochs,
-        )
-        return self.get_parameters(config={}), len(self.trainloader), {}
+        set_weights(self.net, parameters)
+        train(self.net, self.trainloader, epochs=self.local_epochs, device=self.device)
+        return get_weights(self.net), len(self.trainloader), {}
 
     def evaluate(self, parameters, config):
-        self.set_parameters(parameters)
-        loss, accuracy = test(self.net, self.testloader)
+        set_weights(self.net, parameters)
+        loss, accuracy = test(self.net, self.testloader, self.device)
         return float(loss), len(self.testloader), {"accuracy": accuracy}
 
 
 def client_fn(context: Context):
-    # Load model and data
-    net = AutoModelForSequenceClassification.from_pretrained(
-        CHECKPOINT, num_labels=2
-    ).to(DEVICE)
 
+    # Get this client's dataset partition
     partition_id = context.node_config["partition-id"]
     num_partitions = context.node_config["num-partitions"]
-    trainloader, valloader = load_data(partition_id, num_partitions)
+    model_name = context.run_config["model-name"]
+    trainloader, valloader = load_data(partition_id, num_partitions, model_name)
+
+    # Load model
+    num_labels = context.run_config["num-labels"]
+    net = AutoModelForSequenceClassification.from_pretrained(
+        model_name, num_labels=num_labels
+    )
+
     local_epochs = context.run_config["local-epochs"]
 
     # Return Client instance

flwr/cli/new/templates/app/code/client.tensorflow.py.tpl
CHANGED
@@ -17,9 +17,6 @@ class FlowerClient(NumPyClient):
         self.batch_size = batch_size
         self.verbose = verbose
 
-    def get_parameters(self, config):
-        return self.model.get_weights()
-
     def fit(self, parameters, config):
         self.model.set_weights(parameters)
         self.model.fit(

flwr/cli/new/templates/app/code/server.huggingface.py.tpl
CHANGED
@@ -1,18 +1,33 @@
 """$project_name: A Flower / $framework_str app."""
 
-from flwr.common import Context
-from flwr.server.strategy import FedAvg
+from flwr.common import Context, ndarrays_to_parameters
 from flwr.server import ServerApp, ServerAppComponents, ServerConfig
+from flwr.server.strategy import FedAvg
+from transformers import AutoModelForSequenceClassification
+
+from $import_name.task import get_weights
 
 
 def server_fn(context: Context):
     # Read from config
     num_rounds = context.run_config["num-server-rounds"]
+    fraction_fit = context.run_config["fraction-fit"]
+
+    # Initialize global model
+    model_name = context.run_config["model-name"]
+    num_labels = context.run_config["num-labels"]
+    net = AutoModelForSequenceClassification.from_pretrained(
+        model_name, num_labels=num_labels
+    )
+
+    weights = get_weights(net)
+    initial_parameters = ndarrays_to_parameters(weights)
 
     # Define strategy
     strategy = FedAvg(
-        fraction_fit=1.0,
+        fraction_fit=fraction_fit,
         fraction_evaluate=1.0,
+        initial_parameters=initial_parameters,
     )
     config = ServerConfig(num_rounds=num_rounds)
 

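The hunk above only touches `server_fn`; the rest of the template, which this diff does not show, still wraps it into a `ServerApp`. Roughly, as a sketch rather than a verbatim quote of the template:

    # Sketch: how a server_fn like the one above is typically consumed.
    from flwr.common import Context
    from flwr.server import ServerApp, ServerAppComponents, ServerConfig
    from flwr.server.strategy import FedAvg


    def server_fn(context: Context):
        # Minimal stand-in for the template's server_fn shown above.
        strategy = FedAvg()
        config = ServerConfig(num_rounds=3)
        return ServerAppComponents(strategy=strategy, config=config)


    app = ServerApp(server_fn=server_fn)
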
flwr/cli/new/templates/app/code/task.huggingface.py.tpl
CHANGED
@@ -4,24 +4,25 @@ import warnings
 from collections import OrderedDict
 
 import torch
+import transformers
+from datasets.utils.logging import disable_progress_bar
 from evaluate import load as load_metric
+from flwr_datasets import FederatedDataset
+from flwr_datasets.partitioner import IidPartitioner
 from torch.optim import AdamW
 from torch.utils.data import DataLoader
 from transformers import AutoTokenizer, DataCollatorWithPadding
 
-from flwr_datasets import FederatedDataset
-from flwr_datasets.partitioner import IidPartitioner
-
-
 warnings.filterwarnings("ignore", category=UserWarning)
-DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-CHECKPOINT = "distilbert-base-uncased"  # transformer model checkpoint
+warnings.filterwarnings("ignore", category=FutureWarning)
+disable_progress_bar()
+transformers.logging.set_verbosity_error()
 
 
 fds = None  # Cache FederatedDataset
 
 
-def load_data(partition_id: int, num_partitions: int):
+def load_data(partition_id: int, num_partitions: int, model_name: str):
     """Load IMDB data (training and eval)"""
     # Only initialize `FederatedDataset` once
     global fds
@@ -35,10 +36,12 @@ def load_data(partition_id: int, num_partitions: int):
     # Divide data: 80% train, 20% test
     partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
 
-    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
 
     def tokenize_function(examples):
-        return tokenizer(examples["text"], truncation=True)
+        return tokenizer(
+            examples["text"], truncation=True, add_special_tokens=True, max_length=512
+        )
 
     partition_train_test = partition_train_test.map(tokenize_function, batched=True)
     partition_train_test = partition_train_test.remove_columns("text")
@@ -59,12 +62,12 @@ def load_data(partition_id: int, num_partitions: int):
     return trainloader, testloader
 
 
-def train(net, trainloader, epochs):
+def train(net, trainloader, epochs, device):
     optimizer = AdamW(net.parameters(), lr=5e-5)
     net.train()
     for _ in range(epochs):
         for batch in trainloader:
-            batch = {k: v.to(DEVICE) for k, v in batch.items()}
+            batch = {k: v.to(device) for k, v in batch.items()}
             outputs = net(**batch)
             loss = outputs.loss
             loss.backward()
@@ -72,12 +75,12 @@ def train(net, trainloader, epochs):
             optimizer.zero_grad()
 
 
-def test(net, testloader):
+def test(net, testloader, device):
     metric = load_metric("accuracy")
     loss = 0
     net.eval()
     for batch in testloader:
-        batch = {k: v.to(DEVICE) for k, v in batch.items()}
+        batch = {k: v.to(device) for k, v in batch.items()}
         with torch.no_grad():
             outputs = net(**batch)
             logits = outputs.logits

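Inside a project generated from this template, the updated helpers can be exercised directly from a small local script. A sketch under the assumption that it runs next to the generated `task.py`; the checkpoint and partition numbers simply reuse the template defaults shown in this diff:

    # Sketch: exercise the updated task helpers from a generated project.
    # Assumes this file sits next to the generated task.py module.
    import torch
    from transformers import AutoModelForSequenceClassification

    from task import load_data, test, train  # generated module shown above

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # model_name is now passed explicitly instead of a module-level CHECKPOINT
    trainloader, testloader = load_data(
        partition_id=0, num_partitions=10, model_name="prajjwal1/bert-tiny"
    )

    net = AutoModelForSequenceClassification.from_pretrained(
        "prajjwal1/bert-tiny", num_labels=2
    )
    net.to(device)

    train(net, trainloader, epochs=1, device=device)
    loss, accuracy = test(net, testloader, device)
    print(f"loss={loss:.4f} accuracy={accuracy:.4f}")
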
flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl
CHANGED
@@ -8,7 +8,7 @@ version = "1.0.0"
 description = ""
 license = "Apache-2.0"
 dependencies = [
-    "flwr[simulation]>=1.10.0",
+    "flwr[simulation]>=1.11.0",
     "flwr-datasets>=0.3.0",
     "torch==2.2.1",
     "transformers>=4.30.0,<5.0",
@@ -29,10 +29,18 @@ clientapp = "$import_name.client_app:app"
 
 [tool.flwr.app.config]
 num-server-rounds = 3
+fraction-fit = 0.5
 local-epochs = 1
+model-name = "prajjwal1/bert-tiny"  # Set a larger model if you have access to more GPU resources
+num-labels = 2
 
 [tool.flwr.federations]
 default = "localhost"
 
 [tool.flwr.federations.localhost]
 options.num-supernodes = 10
+
+[tool.flwr.federations.localhost-gpu]
+options.num-supernodes = 10
+options.backend.client-resources.num-cpus = 4 # each ClientApp assumes to use 4CPUs
+options.backend.client-resources.num-gpus = 0.25 # at most 4 ClientApps will run in a given GPU

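With the added `[tool.flwr.federations.localhost-gpu]` table, a generated app can be run against either federation from the project directory, e.g. `flwr run .` for the default `localhost` federation or `flwr run . localhost-gpu` to apply the GPU-backed client resources. Individual `[tool.flwr.app.config]` values such as `num-server-rounds` can also be overridden at launch through the `--run-config` option handled in `flwr/cli/run/run.py` below; the exact override syntax follows the Flower CLI documentation.
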
flwr/cli/run/run.py
CHANGED
@@ -20,10 +20,9 @@ import subprocess
 import sys
 from logging import DEBUG
 from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Annotated, Any, Optional
 
 import typer
-from typing_extensions import Annotated
 
 from flwr.cli.build import build
 from flwr.cli.config_utils import load_and_validate
@@ -52,7 +51,7 @@ def run(
         typer.Argument(help="Name of the federation to run the app on."),
     ] = None,
     config_overrides: Annotated[
-        Optional[List[str]],
+        Optional[list[str]],
         typer.Option(
             "--run-config",
             "-c",
@@ -125,8 +124,8 @@
 
 def _run_with_superexec(
     app: Path,
-    federation_config: Dict[str, Any],
-    config_overrides: Optional[List[str]],
+    federation_config: dict[str, Any],
+    config_overrides: Optional[list[str]],
 ) -> None:
 
     insecure_str = federation_config.get("insecure")
@@ -187,8 +186,8 @@ def _run_with_superexec(
 
 def _run_without_superexec(
     app: Optional[Path],
-    federation_config: Dict[str, Any],
-    config_overrides: Optional[List[str]],
+    federation_config: dict[str, Any],
+    config_overrides: Optional[list[str]],
     federation: str,
 ) -> None:
     try:

flwr/cli/utils.py
CHANGED
@@ -17,7 +17,7 @@
 import hashlib
 import re
 from pathlib import Path
-from typing import Callable, List, Optional, cast
+from typing import Callable, Optional, cast
 
 import typer
 
@@ -40,7 +40,7 @@ def prompt_text(
     return cast(str, result)
 
 
-def prompt_options(text: str, options: List[str]) -> str:
+def prompt_options(text: str, options: list[str]) -> str:
     """Ask user to select one of the given options and return the selected item."""
     # Turn options into a list with index as in " [ 0] quickstart-pytorch"
     options_formatted = [

flwr/client/app.py
CHANGED
@@ -18,10 +18,11 @@ import signal
 import subprocess
 import sys
 import time
+from contextlib import AbstractContextManager
 from dataclasses import dataclass
 from logging import ERROR, INFO, WARN
 from pathlib import Path
-from typing import Callable, ContextManager, Dict, Optional, Tuple, Type, Union, cast
+from typing import Callable, Optional, Union, cast
 
 import grpc
 from cryptography.hazmat.primitives.asymmetric import ec
@@ -35,6 +36,7 @@ from flwr.client.typing import ClientFnExt
 from flwr.common import GRPC_MAX_MESSAGE_LENGTH, Context, EventType, Message, event
 from flwr.common.address import parse_address
 from flwr.common.constant import (
+    CLIENTAPPIO_API_DEFAULT_ADDRESS,
     MISSING_EXTRA_REST,
     RUN_ID_NUM_BYTES,
     TRANSPORT_TYPE_GRPC_ADAPTER,
@@ -60,8 +62,6 @@ from .message_handler.message_handler import handle_control_message
 from .node_state import NodeState
 from .numpy_client import NumPyClient
 
-ADDRESS_CLIENTAPPIO_API_GRPC_RERE = "0.0.0.0:9094"
-
 ISOLATION_MODE_SUBPROCESS = "subprocess"
 ISOLATION_MODE_PROCESS = "process"
 
@@ -95,7 +95,7 @@ def start_client(
     insecure: Optional[bool] = None,
     transport: Optional[str] = None,
     authentication_keys: Optional[
-        Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
+        tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
     ] = None,
     max_retries: Optional[int] = None,
     max_wait_time: Optional[float] = None,
@@ -205,13 +205,13 @@ def start_client_internal(
     insecure: Optional[bool] = None,
     transport: Optional[str] = None,
     authentication_keys: Optional[
-        Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
+        tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
     ] = None,
     max_retries: Optional[int] = None,
     max_wait_time: Optional[float] = None,
     flwr_path: Optional[Path] = None,
     isolation: Optional[str] = None,
-    supernode_address: Optional[str] = ADDRESS_CLIENTAPPIO_API_GRPC_RERE,
+    supernode_address: Optional[str] = CLIENTAPPIO_API_DEFAULT_ADDRESS,
 ) -> None:
     """Start a Flower client node which connects to a Flower server.
 
@@ -266,7 +266,7 @@ def start_client_internal(
         by the SueprNode and communicates using gRPC at the address
         `supernode_address`. If `process`, the `ClientApp` runs in a separate isolated
         process and communicates using gRPC at the address `supernode_address`.
-    supernode_address : Optional[str] (default: `ADDRESS_CLIENTAPPIO_API_GRPC_RERE`)
+    supernode_address : Optional[str] (default: `CLIENTAPPIO_API_DEFAULT_ADDRESS`)
         The SuperNode gRPC server address.
     """
     if insecure is None:
@@ -357,7 +357,7 @@ def start_client_internal(
     # NodeState gets initialized when the first connection is established
     node_state: Optional[NodeState] = None
 
-    runs: Dict[int, Run] = {}
+    runs: dict[int, Run] = {}
 
     while not app_state_tracker.interrupt:
         sleep_duration: int = 0
@@ -690,7 +690,7 @@ def start_numpy_client(
     )
 
 
-def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
+def _init_connection(transport: Optional[str], server_address: str) -> tuple[
     Callable[
         [
             str,
@@ -698,10 +698,10 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
         RetryInvoker,
         int,
         Union[bytes, str, None],
-        Optional[Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]],
+        Optional[tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]],
     ],
-    ContextManager[
-        Tuple[
+    AbstractContextManager[
+        tuple[
             Callable[[], Optional[Message]],
             Callable[[Message], None],
             Optional[Callable[[], Optional[int]]],
@@ -712,7 +712,7 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
         ],
     ],
     str,
-    Type[Exception],
+    type[Exception],
 ]:
     # Parse IP address
     parsed_address = parse_address(server_address)
@@ -770,7 +770,7 @@ class _AppStateTracker:
         signal.signal(signal.SIGTERM, signal_handler)
 
 
-def run_clientappio_api_grpc(address: str) -> Tuple[grpc.Server, ClientAppIoServicer]:
+def run_clientappio_api_grpc(address: str) -> tuple[grpc.Server, ClientAppIoServicer]:
     """Run ClientAppIo API gRPC server."""
     clientappio_servicer: grpc.Server = ClientAppIoServicer()
     clientappio_add_servicer_to_server_fn = add_ClientAppIoServicer_to_server