flwr-nightly 1.8.0.dev20240314__py3-none-any.whl → 1.11.0.dev20240813__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/app.py +7 -0
- flwr/cli/build.py +150 -0
- flwr/cli/config_utils.py +219 -0
- flwr/cli/example.py +3 -1
- flwr/cli/install.py +227 -0
- flwr/cli/new/new.py +179 -48
- flwr/cli/new/templates/app/.gitignore.tpl +160 -0
- flwr/cli/new/templates/app/README.flowertune.md.tpl +56 -0
- flwr/cli/new/templates/app/README.md.tpl +1 -5
- flwr/cli/new/templates/app/code/__init__.py.tpl +1 -1
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +65 -0
- flwr/cli/new/templates/app/code/client.jax.py.tpl +56 -0
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +93 -0
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +3 -2
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +23 -11
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +97 -0
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +60 -1
- flwr/cli/new/templates/app/code/flwr_tune/__init__.py +15 -0
- flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +89 -0
- flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +126 -0
- flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +34 -0
- flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +57 -0
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +59 -0
- flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +48 -0
- flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +11 -0
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -0
- flwr/cli/new/templates/app/code/server.jax.py.tpl +20 -0
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +20 -0
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +17 -9
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +21 -18
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +24 -0
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +29 -1
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +99 -0
- flwr/cli/new/templates/app/code/task.jax.py.tpl +57 -0
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +102 -0
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +28 -23
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +53 -0
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +39 -0
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +38 -0
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +34 -0
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +39 -0
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +25 -12
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +29 -14
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +33 -0
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +29 -14
- flwr/cli/run/run.py +168 -17
- flwr/cli/utils.py +75 -4
- flwr/client/__init__.py +6 -1
- flwr/client/app.py +239 -248
- flwr/client/client_app.py +70 -9
- flwr/client/dpfedavg_numpy_client.py +1 -1
- flwr/client/grpc_adapter_client/__init__.py +15 -0
- flwr/client/grpc_adapter_client/connection.py +97 -0
- flwr/client/grpc_client/connection.py +18 -5
- flwr/client/grpc_rere_client/__init__.py +1 -1
- flwr/client/grpc_rere_client/client_interceptor.py +158 -0
- flwr/client/grpc_rere_client/connection.py +127 -33
- flwr/client/grpc_rere_client/grpc_adapter.py +140 -0
- flwr/client/heartbeat.py +74 -0
- flwr/client/message_handler/__init__.py +1 -1
- flwr/client/message_handler/message_handler.py +7 -7
- flwr/client/mod/__init__.py +5 -5
- flwr/client/mod/centraldp_mods.py +4 -2
- flwr/client/mod/comms_mods.py +4 -4
- flwr/client/mod/localdp_mod.py +9 -4
- flwr/client/mod/secure_aggregation/__init__.py +1 -1
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +1 -1
- flwr/client/mod/utils.py +1 -1
- flwr/client/node_state.py +60 -10
- flwr/client/node_state_tests.py +4 -3
- flwr/client/rest_client/__init__.py +1 -1
- flwr/client/rest_client/connection.py +177 -157
- flwr/client/supernode/__init__.py +26 -0
- flwr/client/supernode/app.py +464 -0
- flwr/client/typing.py +1 -0
- flwr/common/__init__.py +13 -11
- flwr/common/address.py +1 -1
- flwr/common/config.py +193 -0
- flwr/common/constant.py +42 -1
- flwr/common/context.py +26 -1
- flwr/common/date.py +1 -1
- flwr/common/dp.py +1 -1
- flwr/common/grpc.py +6 -2
- flwr/common/logger.py +79 -8
- flwr/common/message.py +167 -105
- flwr/common/object_ref.py +126 -25
- flwr/common/record/__init__.py +1 -1
- flwr/common/record/parametersrecord.py +0 -1
- flwr/common/record/recordset.py +78 -27
- flwr/common/recordset_compat.py +8 -1
- flwr/common/retry_invoker.py +25 -13
- flwr/common/secure_aggregation/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/shamir.py +1 -1
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +21 -2
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
- flwr/common/secure_aggregation/quantization.py +1 -1
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
- flwr/common/serde.py +209 -3
- flwr/common/telemetry.py +25 -0
- flwr/common/typing.py +38 -0
- flwr/common/version.py +14 -0
- flwr/proto/clientappio_pb2.py +41 -0
- flwr/proto/clientappio_pb2.pyi +110 -0
- flwr/proto/clientappio_pb2_grpc.py +101 -0
- flwr/proto/clientappio_pb2_grpc.pyi +40 -0
- flwr/proto/common_pb2.py +36 -0
- flwr/proto/common_pb2.pyi +121 -0
- flwr/proto/common_pb2_grpc.py +4 -0
- flwr/proto/common_pb2_grpc.pyi +4 -0
- flwr/proto/driver_pb2.py +26 -19
- flwr/proto/driver_pb2.pyi +34 -0
- flwr/proto/driver_pb2_grpc.py +70 -0
- flwr/proto/driver_pb2_grpc.pyi +28 -0
- flwr/proto/exec_pb2.py +43 -0
- flwr/proto/exec_pb2.pyi +95 -0
- flwr/proto/exec_pb2_grpc.py +101 -0
- flwr/proto/exec_pb2_grpc.pyi +41 -0
- flwr/proto/fab_pb2.py +30 -0
- flwr/proto/fab_pb2.pyi +56 -0
- flwr/proto/fab_pb2_grpc.py +4 -0
- flwr/proto/fab_pb2_grpc.pyi +4 -0
- flwr/proto/fleet_pb2.py +29 -23
- flwr/proto/fleet_pb2.pyi +33 -0
- flwr/proto/fleet_pb2_grpc.py +102 -0
- flwr/proto/fleet_pb2_grpc.pyi +35 -0
- flwr/proto/grpcadapter_pb2.py +32 -0
- flwr/proto/grpcadapter_pb2.pyi +43 -0
- flwr/proto/grpcadapter_pb2_grpc.py +66 -0
- flwr/proto/grpcadapter_pb2_grpc.pyi +24 -0
- flwr/proto/message_pb2.py +41 -0
- flwr/proto/message_pb2.pyi +122 -0
- flwr/proto/message_pb2_grpc.py +4 -0
- flwr/proto/message_pb2_grpc.pyi +4 -0
- flwr/proto/run_pb2.py +35 -0
- flwr/proto/run_pb2.pyi +76 -0
- flwr/proto/run_pb2_grpc.py +4 -0
- flwr/proto/run_pb2_grpc.pyi +4 -0
- flwr/proto/task_pb2.py +7 -8
- flwr/proto/task_pb2.pyi +8 -5
- flwr/server/__init__.py +4 -8
- flwr/server/app.py +298 -350
- flwr/server/compat/app.py +6 -57
- flwr/server/compat/app_utils.py +5 -4
- flwr/server/compat/driver_client_proxy.py +29 -48
- flwr/server/compat/legacy_context.py +5 -4
- flwr/server/driver/__init__.py +2 -0
- flwr/server/driver/driver.py +22 -132
- flwr/server/driver/grpc_driver.py +224 -74
- flwr/server/driver/inmemory_driver.py +183 -0
- flwr/server/history.py +20 -20
- flwr/server/run_serverapp.py +121 -34
- flwr/server/server.py +11 -7
- flwr/server/server_app.py +59 -10
- flwr/server/serverapp_components.py +52 -0
- flwr/server/strategy/__init__.py +2 -2
- flwr/server/strategy/bulyan.py +1 -1
- flwr/server/strategy/dp_adaptive_clipping.py +3 -3
- flwr/server/strategy/dp_fixed_clipping.py +4 -3
- flwr/server/strategy/dpfedavg_adaptive.py +1 -1
- flwr/server/strategy/dpfedavg_fixed.py +1 -1
- flwr/server/strategy/fedadagrad.py +1 -1
- flwr/server/strategy/fedadam.py +1 -1
- flwr/server/strategy/fedavg_android.py +1 -1
- flwr/server/strategy/fedavgm.py +1 -1
- flwr/server/strategy/fedmedian.py +1 -1
- flwr/server/strategy/fedopt.py +1 -1
- flwr/server/strategy/fedprox.py +1 -1
- flwr/server/strategy/fedxgb_bagging.py +1 -1
- flwr/server/strategy/fedxgb_cyclic.py +1 -1
- flwr/server/strategy/fedxgb_nn_avg.py +1 -1
- flwr/server/strategy/fedyogi.py +1 -1
- flwr/server/strategy/krum.py +1 -1
- flwr/server/strategy/qfedavg.py +1 -1
- flwr/server/superlink/driver/__init__.py +1 -1
- flwr/server/superlink/driver/driver_grpc.py +1 -1
- flwr/server/superlink/driver/driver_servicer.py +51 -4
- flwr/server/superlink/ffs/__init__.py +24 -0
- flwr/server/superlink/ffs/disk_ffs.py +104 -0
- flwr/server/superlink/ffs/ffs.py +79 -0
- flwr/server/superlink/fleet/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_adapter/__init__.py +15 -0
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +131 -0
- flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +8 -2
- flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +30 -2
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +214 -0
- flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
- flwr/server/superlink/fleet/message_handler/message_handler.py +42 -2
- flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/rest_rere/rest_api.py +59 -1
- flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
- flwr/server/superlink/fleet/vce/backend/backend.py +5 -5
- flwr/server/superlink/fleet/vce/backend/raybackend.py +53 -56
- flwr/server/superlink/fleet/vce/vce_api.py +190 -127
- flwr/server/superlink/state/__init__.py +1 -1
- flwr/server/superlink/state/in_memory_state.py +159 -42
- flwr/server/superlink/state/sqlite_state.py +243 -39
- flwr/server/superlink/state/state.py +81 -6
- flwr/server/superlink/state/state_factory.py +11 -2
- flwr/server/superlink/state/utils.py +62 -0
- flwr/server/typing.py +2 -0
- flwr/server/utils/__init__.py +1 -1
- flwr/server/utils/tensorboard.py +1 -1
- flwr/server/utils/validator.py +23 -9
- flwr/server/workflow/default_workflows.py +67 -25
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -6
- flwr/simulation/__init__.py +7 -4
- flwr/simulation/app.py +67 -36
- flwr/simulation/ray_transport/__init__.py +1 -1
- flwr/simulation/ray_transport/ray_actor.py +20 -46
- flwr/simulation/ray_transport/ray_client_proxy.py +36 -16
- flwr/simulation/run_simulation.py +308 -92
- flwr/superexec/__init__.py +21 -0
- flwr/superexec/app.py +184 -0
- flwr/superexec/deployment.py +185 -0
- flwr/superexec/exec_grpc.py +55 -0
- flwr/superexec/exec_servicer.py +70 -0
- flwr/superexec/executor.py +75 -0
- flwr/superexec/simulation.py +193 -0
- {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/METADATA +10 -6
- flwr_nightly-1.11.0.dev20240813.dist-info/RECORD +288 -0
- flwr_nightly-1.11.0.dev20240813.dist-info/entry_points.txt +10 -0
- flwr/cli/flower_toml.py +0 -140
- flwr/cli/new/templates/app/flower.toml.tpl +0 -13
- flwr/cli/new/templates/app/requirements.numpy.txt.tpl +0 -2
- flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +0 -4
- flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl +0 -4
- flwr_nightly-1.8.0.dev20240314.dist-info/RECORD +0 -211
- flwr_nightly-1.8.0.dev20240314.dist-info/entry_points.txt +0 -9
- {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/WHEEL +0 -0
flwr/cli/new/new.py
CHANGED
|
@@ -14,23 +14,43 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Flower command line interface `new` command."""
|
|
16
16
|
|
|
17
|
-
import
|
|
17
|
+
import re
|
|
18
18
|
from enum import Enum
|
|
19
|
+
from pathlib import Path
|
|
19
20
|
from string import Template
|
|
20
21
|
from typing import Dict, Optional
|
|
21
22
|
|
|
22
23
|
import typer
|
|
23
24
|
from typing_extensions import Annotated
|
|
24
25
|
|
|
25
|
-
from ..utils import
|
|
26
|
+
from ..utils import (
|
|
27
|
+
is_valid_project_name,
|
|
28
|
+
prompt_options,
|
|
29
|
+
prompt_text,
|
|
30
|
+
sanitize_project_name,
|
|
31
|
+
)
|
|
26
32
|
|
|
27
33
|
|
|
28
34
|
class MlFramework(str, Enum):
|
|
29
35
|
"""Available frameworks."""
|
|
30
36
|
|
|
31
|
-
NUMPY = "NumPy"
|
|
32
37
|
PYTORCH = "PyTorch"
|
|
33
38
|
TENSORFLOW = "TensorFlow"
|
|
39
|
+
SKLEARN = "sklearn"
|
|
40
|
+
HUGGINGFACE = "HuggingFace"
|
|
41
|
+
JAX = "JAX"
|
|
42
|
+
MLX = "MLX"
|
|
43
|
+
NUMPY = "NumPy"
|
|
44
|
+
FLOWERTUNE = "FlowerTune"
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class LlmChallengeName(str, Enum):
|
|
48
|
+
"""Available LLM challenges."""
|
|
49
|
+
|
|
50
|
+
GENERALNLP = "GeneralNLP"
|
|
51
|
+
FINANCE = "Finance"
|
|
52
|
+
MEDICAL = "Medical"
|
|
53
|
+
CODE = "Code"
|
|
34
54
|
|
|
35
55
|
|
|
36
56
|
class TemplateNotFound(Exception):
|
|
@@ -39,10 +59,10 @@ class TemplateNotFound(Exception):
|
|
|
39
59
|
|
|
40
60
|
def load_template(name: str) -> str:
|
|
41
61
|
"""Load template from template directory and return as text."""
|
|
42
|
-
tpl_dir =
|
|
43
|
-
tpl_file_path =
|
|
62
|
+
tpl_dir = (Path(__file__).parent / "templates").absolute()
|
|
63
|
+
tpl_file_path = tpl_dir / name
|
|
44
64
|
|
|
45
|
-
if not
|
|
65
|
+
if not tpl_file_path.is_file():
|
|
46
66
|
raise TemplateNotFound(f"Template '{name}' not found")
|
|
47
67
|
|
|
48
68
|
with open(tpl_file_path, encoding="utf-8") as tpl_file:
|
|
@@ -53,47 +73,69 @@ def render_template(template: str, data: Dict[str, str]) -> str:
|
|
|
53
73
|
"""Render template."""
|
|
54
74
|
tpl_file = load_template(template)
|
|
55
75
|
tpl = Template(tpl_file)
|
|
56
|
-
|
|
57
|
-
|
|
76
|
+
if ".gitignore" not in template:
|
|
77
|
+
return tpl.substitute(data)
|
|
78
|
+
return tpl.template
|
|
58
79
|
|
|
59
80
|
|
|
60
|
-
def create_file(file_path:
|
|
81
|
+
def create_file(file_path: Path, content: str) -> None:
|
|
61
82
|
"""Create file including all nessecary directories and write content into file."""
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
f.write(content)
|
|
83
|
+
file_path.parent.mkdir(exist_ok=True)
|
|
84
|
+
file_path.write_text(content)
|
|
65
85
|
|
|
66
86
|
|
|
67
|
-
def render_and_create(file_path:
|
|
87
|
+
def render_and_create(file_path: Path, template: str, context: Dict[str, str]) -> None:
|
|
68
88
|
"""Render template and write to file."""
|
|
69
89
|
content = render_template(template, context)
|
|
70
90
|
create_file(file_path, content)
|
|
71
91
|
|
|
72
92
|
|
|
93
|
+
# pylint: disable=too-many-locals,too-many-branches,too-many-statements
|
|
73
94
|
def new(
|
|
74
|
-
|
|
95
|
+
app_name: Annotated[
|
|
75
96
|
Optional[str],
|
|
76
|
-
typer.Argument(
|
|
97
|
+
typer.Argument(help="The name of the Flower App"),
|
|
77
98
|
] = None,
|
|
78
99
|
framework: Annotated[
|
|
79
100
|
Optional[MlFramework],
|
|
80
101
|
typer.Option(case_sensitive=False, help="The ML framework to use"),
|
|
81
102
|
] = None,
|
|
103
|
+
username: Annotated[
|
|
104
|
+
Optional[str],
|
|
105
|
+
typer.Option(case_sensitive=False, help="The Flower username of the author"),
|
|
106
|
+
] = None,
|
|
82
107
|
) -> None:
|
|
83
|
-
"""Create new Flower
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
108
|
+
"""Create new Flower App."""
|
|
109
|
+
if app_name is None:
|
|
110
|
+
app_name = prompt_text("Please provide the app name")
|
|
111
|
+
if not is_valid_project_name(app_name):
|
|
112
|
+
app_name = prompt_text(
|
|
113
|
+
"Please provide a name that only contains "
|
|
114
|
+
"characters in {'-', a-zA-Z', '0-9'}",
|
|
115
|
+
predicate=is_valid_project_name,
|
|
116
|
+
default=sanitize_project_name(app_name),
|
|
89
117
|
)
|
|
90
|
-
)
|
|
91
118
|
|
|
92
|
-
|
|
93
|
-
|
|
119
|
+
# Set project directory path
|
|
120
|
+
package_name = re.sub(r"[-_.]+", "-", app_name).lower()
|
|
121
|
+
import_name = package_name.replace("-", "_")
|
|
122
|
+
project_dir = Path.cwd() / package_name
|
|
123
|
+
|
|
124
|
+
if project_dir.exists():
|
|
125
|
+
if not typer.confirm(
|
|
126
|
+
typer.style(
|
|
127
|
+
f"\n💬 {app_name} already exists, do you want to override it?",
|
|
128
|
+
fg=typer.colors.MAGENTA,
|
|
129
|
+
bold=True,
|
|
130
|
+
)
|
|
131
|
+
):
|
|
132
|
+
return
|
|
133
|
+
|
|
134
|
+
if username is None:
|
|
135
|
+
username = prompt_text("Please provide your Flower username")
|
|
94
136
|
|
|
95
137
|
if framework is not None:
|
|
96
|
-
|
|
138
|
+
framework_str_upper = str(framework.value)
|
|
97
139
|
else:
|
|
98
140
|
framework_value = prompt_options(
|
|
99
141
|
"Please select ML framework by typing in the number",
|
|
@@ -104,50 +146,139 @@ def new(
|
|
|
104
146
|
for name, value in vars(MlFramework).items()
|
|
105
147
|
if value == framework_value
|
|
106
148
|
]
|
|
107
|
-
|
|
149
|
+
framework_str_upper = selected_value[0]
|
|
108
150
|
|
|
109
|
-
framework_str =
|
|
151
|
+
framework_str = framework_str_upper.lower()
|
|
110
152
|
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
153
|
+
llm_challenge_str = None
|
|
154
|
+
if framework_str == "flowertune":
|
|
155
|
+
llm_challenge_value = prompt_options(
|
|
156
|
+
"Please select LLM challenge by typing in the number",
|
|
157
|
+
sorted([challenge.value for challenge in LlmChallengeName]),
|
|
158
|
+
)
|
|
159
|
+
selected_value = [
|
|
160
|
+
name
|
|
161
|
+
for name, value in vars(LlmChallengeName).items()
|
|
162
|
+
if value == llm_challenge_value
|
|
163
|
+
]
|
|
164
|
+
llm_challenge_str = selected_value[0]
|
|
165
|
+
llm_challenge_str = llm_challenge_str.lower()
|
|
115
166
|
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
167
|
+
print(
|
|
168
|
+
typer.style(
|
|
169
|
+
f"\n🔨 Creating Flower App {app_name}...",
|
|
170
|
+
fg=typer.colors.GREEN,
|
|
171
|
+
bold=True,
|
|
172
|
+
)
|
|
173
|
+
)
|
|
174
|
+
|
|
175
|
+
context = {
|
|
176
|
+
"framework_str": framework_str_upper,
|
|
177
|
+
"import_name": import_name.replace("-", "_"),
|
|
178
|
+
"package_name": package_name,
|
|
179
|
+
"project_name": app_name,
|
|
180
|
+
"username": username,
|
|
125
181
|
}
|
|
126
182
|
|
|
127
|
-
#
|
|
128
|
-
if
|
|
129
|
-
files
|
|
183
|
+
# List of files to render
|
|
184
|
+
if llm_challenge_str:
|
|
185
|
+
files = {
|
|
186
|
+
".gitignore": {"template": "app/.gitignore.tpl"},
|
|
187
|
+
"pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"},
|
|
188
|
+
"README.md": {"template": f"app/README.{framework_str}.md.tpl"},
|
|
189
|
+
f"{import_name}/__init__.py": {"template": "app/code/__init__.py.tpl"},
|
|
190
|
+
f"{import_name}/server.py": {
|
|
191
|
+
"template": "app/code/flwr_tune/server.py.tpl"
|
|
192
|
+
},
|
|
193
|
+
f"{import_name}/client.py": {
|
|
194
|
+
"template": "app/code/flwr_tune/client.py.tpl"
|
|
195
|
+
},
|
|
196
|
+
f"{import_name}/app.py": {"template": "app/code/flwr_tune/app.py.tpl"},
|
|
197
|
+
f"{import_name}/models.py": {
|
|
198
|
+
"template": "app/code/flwr_tune/models.py.tpl"
|
|
199
|
+
},
|
|
200
|
+
f"{import_name}/dataset.py": {
|
|
201
|
+
"template": "app/code/flwr_tune/dataset.py.tpl"
|
|
202
|
+
},
|
|
203
|
+
f"{import_name}/conf/config.yaml": {
|
|
204
|
+
"template": "app/code/flwr_tune/config.yaml.tpl"
|
|
205
|
+
},
|
|
206
|
+
f"{import_name}/conf/static_config.yaml": {
|
|
207
|
+
"template": "app/code/flwr_tune/static_config.yaml.tpl"
|
|
208
|
+
},
|
|
209
|
+
}
|
|
210
|
+
|
|
211
|
+
# Challenge specific context
|
|
212
|
+
fraction_fit = "0.2" if llm_challenge_str == "code" else "0.1"
|
|
213
|
+
if llm_challenge_str == "generalnlp":
|
|
214
|
+
challenge_name = "General NLP"
|
|
215
|
+
num_clients = "20"
|
|
216
|
+
dataset_name = "vicgalle/alpaca-gpt4"
|
|
217
|
+
elif llm_challenge_str == "finance":
|
|
218
|
+
challenge_name = "Finance"
|
|
219
|
+
num_clients = "50"
|
|
220
|
+
dataset_name = "FinGPT/fingpt-sentiment-train"
|
|
221
|
+
elif llm_challenge_str == "medical":
|
|
222
|
+
challenge_name = "Medical"
|
|
223
|
+
num_clients = "20"
|
|
224
|
+
dataset_name = "medalpaca/medical_meadow_medical_flashcards"
|
|
225
|
+
else:
|
|
226
|
+
challenge_name = "Code"
|
|
227
|
+
num_clients = "10"
|
|
228
|
+
dataset_name = "lucasmccabe-lmi/CodeAlpaca-20k"
|
|
130
229
|
|
|
131
|
-
|
|
230
|
+
context["llm_challenge_str"] = llm_challenge_str
|
|
231
|
+
context["fraction_fit"] = fraction_fit
|
|
232
|
+
context["challenge_name"] = challenge_name
|
|
233
|
+
context["num_clients"] = num_clients
|
|
234
|
+
context["dataset_name"] = dataset_name
|
|
235
|
+
else:
|
|
236
|
+
files = {
|
|
237
|
+
".gitignore": {"template": "app/.gitignore.tpl"},
|
|
238
|
+
"README.md": {"template": "app/README.md.tpl"},
|
|
239
|
+
"pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"},
|
|
240
|
+
f"{import_name}/__init__.py": {"template": "app/code/__init__.py.tpl"},
|
|
241
|
+
f"{import_name}/server_app.py": {
|
|
242
|
+
"template": f"app/code/server.{framework_str}.py.tpl"
|
|
243
|
+
},
|
|
244
|
+
f"{import_name}/client_app.py": {
|
|
245
|
+
"template": f"app/code/client.{framework_str}.py.tpl"
|
|
246
|
+
},
|
|
247
|
+
}
|
|
248
|
+
|
|
249
|
+
# Depending on the framework, generate task.py file
|
|
250
|
+
frameworks_with_tasks = [
|
|
251
|
+
MlFramework.PYTORCH.value.lower(),
|
|
252
|
+
MlFramework.JAX.value.lower(),
|
|
253
|
+
MlFramework.HUGGINGFACE.value.lower(),
|
|
254
|
+
MlFramework.MLX.value.lower(),
|
|
255
|
+
MlFramework.TENSORFLOW.value.lower(),
|
|
256
|
+
]
|
|
257
|
+
if framework_str in frameworks_with_tasks:
|
|
258
|
+
files[f"{import_name}/task.py"] = {
|
|
259
|
+
"template": f"app/code/task.{framework_str}.py.tpl"
|
|
260
|
+
}
|
|
132
261
|
|
|
133
262
|
for file_path, value in files.items():
|
|
134
263
|
render_and_create(
|
|
135
|
-
file_path=
|
|
264
|
+
file_path=project_dir / file_path,
|
|
136
265
|
template=value["template"],
|
|
137
266
|
context=context,
|
|
138
267
|
)
|
|
139
268
|
|
|
140
269
|
print(
|
|
141
270
|
typer.style(
|
|
142
|
-
"🎊
|
|
143
|
-
"Use the following command to run your
|
|
271
|
+
"🎊 Flower App creation successful.\n\n"
|
|
272
|
+
"Use the following command to run your Flower App:\n",
|
|
144
273
|
fg=typer.colors.GREEN,
|
|
145
274
|
bold=True,
|
|
146
275
|
)
|
|
147
276
|
)
|
|
277
|
+
|
|
278
|
+
_add = " huggingface-cli login\n" if framework_str == "flowertune" else ""
|
|
148
279
|
print(
|
|
149
280
|
typer.style(
|
|
150
|
-
f" cd {
|
|
281
|
+
f" cd {package_name}\n" + " pip install -e .\n" + _add + " flwr run\n",
|
|
151
282
|
fg=typer.colors.BRIGHT_CYAN,
|
|
152
283
|
bold=True,
|
|
153
284
|
)
|
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
# Byte-compiled / optimized / DLL files
|
|
2
|
+
__pycache__/
|
|
3
|
+
*.py[cod]
|
|
4
|
+
*$py.class
|
|
5
|
+
|
|
6
|
+
# C extensions
|
|
7
|
+
*.so
|
|
8
|
+
|
|
9
|
+
# Distribution / packaging
|
|
10
|
+
.Python
|
|
11
|
+
build/
|
|
12
|
+
develop-eggs/
|
|
13
|
+
dist/
|
|
14
|
+
downloads/
|
|
15
|
+
eggs/
|
|
16
|
+
.eggs/
|
|
17
|
+
lib/
|
|
18
|
+
lib64/
|
|
19
|
+
parts/
|
|
20
|
+
sdist/
|
|
21
|
+
var/
|
|
22
|
+
wheels/
|
|
23
|
+
share/python-wheels/
|
|
24
|
+
*.egg-info/
|
|
25
|
+
.installed.cfg
|
|
26
|
+
*.egg
|
|
27
|
+
MANIFEST
|
|
28
|
+
|
|
29
|
+
# PyInstaller
|
|
30
|
+
# Usually these files are written by a python script from a template
|
|
31
|
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
|
32
|
+
*.manifest
|
|
33
|
+
*.spec
|
|
34
|
+
|
|
35
|
+
# Installer logs
|
|
36
|
+
pip-log.txt
|
|
37
|
+
pip-delete-this-directory.txt
|
|
38
|
+
|
|
39
|
+
# Unit test / coverage reports
|
|
40
|
+
htmlcov/
|
|
41
|
+
.tox/
|
|
42
|
+
.nox/
|
|
43
|
+
.coverage
|
|
44
|
+
.coverage.*
|
|
45
|
+
.cache
|
|
46
|
+
nosetests.xml
|
|
47
|
+
coverage.xml
|
|
48
|
+
*.cover
|
|
49
|
+
*.py,cover
|
|
50
|
+
.hypothesis/
|
|
51
|
+
.pytest_cache/
|
|
52
|
+
cover/
|
|
53
|
+
|
|
54
|
+
# Translations
|
|
55
|
+
*.mo
|
|
56
|
+
*.pot
|
|
57
|
+
|
|
58
|
+
# Django stuff:
|
|
59
|
+
*.log
|
|
60
|
+
local_settings.py
|
|
61
|
+
db.sqlite3
|
|
62
|
+
db.sqlite3-journal
|
|
63
|
+
|
|
64
|
+
# Flask stuff:
|
|
65
|
+
instance/
|
|
66
|
+
.webassets-cache
|
|
67
|
+
|
|
68
|
+
# Scrapy stuff:
|
|
69
|
+
.scrapy
|
|
70
|
+
|
|
71
|
+
# Sphinx documentation
|
|
72
|
+
docs/_build/
|
|
73
|
+
|
|
74
|
+
# PyBuilder
|
|
75
|
+
.pybuilder/
|
|
76
|
+
target/
|
|
77
|
+
|
|
78
|
+
# Jupyter Notebook
|
|
79
|
+
.ipynb_checkpoints
|
|
80
|
+
|
|
81
|
+
# IPython
|
|
82
|
+
profile_default/
|
|
83
|
+
ipython_config.py
|
|
84
|
+
|
|
85
|
+
# pyenv
|
|
86
|
+
# For a library or package, you might want to ignore these files since the code is
|
|
87
|
+
# intended to run in multiple environments; otherwise, check them in:
|
|
88
|
+
# .python-version
|
|
89
|
+
|
|
90
|
+
# pipenv
|
|
91
|
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
|
92
|
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
|
93
|
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
|
94
|
+
# install all needed dependencies.
|
|
95
|
+
#Pipfile.lock
|
|
96
|
+
|
|
97
|
+
# poetry
|
|
98
|
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
|
99
|
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
|
100
|
+
# commonly ignored for libraries.
|
|
101
|
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
|
102
|
+
#poetry.lock
|
|
103
|
+
|
|
104
|
+
# pdm
|
|
105
|
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
|
106
|
+
#pdm.lock
|
|
107
|
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
|
108
|
+
# in version control.
|
|
109
|
+
# https://pdm.fming.dev/#use-with-ide
|
|
110
|
+
.pdm.toml
|
|
111
|
+
|
|
112
|
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
|
113
|
+
__pypackages__/
|
|
114
|
+
|
|
115
|
+
# Celery stuff
|
|
116
|
+
celerybeat-schedule
|
|
117
|
+
celerybeat.pid
|
|
118
|
+
|
|
119
|
+
# SageMath parsed files
|
|
120
|
+
*.sage.py
|
|
121
|
+
|
|
122
|
+
# Environments
|
|
123
|
+
.env
|
|
124
|
+
.venv
|
|
125
|
+
env/
|
|
126
|
+
venv/
|
|
127
|
+
ENV/
|
|
128
|
+
env.bak/
|
|
129
|
+
venv.bak/
|
|
130
|
+
|
|
131
|
+
# Spyder project settings
|
|
132
|
+
.spyderproject
|
|
133
|
+
.spyproject
|
|
134
|
+
|
|
135
|
+
# Rope project settings
|
|
136
|
+
.ropeproject
|
|
137
|
+
|
|
138
|
+
# mkdocs documentation
|
|
139
|
+
/site
|
|
140
|
+
|
|
141
|
+
# mypy
|
|
142
|
+
.mypy_cache/
|
|
143
|
+
.dmypy.json
|
|
144
|
+
dmypy.json
|
|
145
|
+
|
|
146
|
+
# Pyre type checker
|
|
147
|
+
.pyre/
|
|
148
|
+
|
|
149
|
+
# pytype static type analyzer
|
|
150
|
+
.pytype/
|
|
151
|
+
|
|
152
|
+
# Cython debug symbols
|
|
153
|
+
cython_debug/
|
|
154
|
+
|
|
155
|
+
# PyCharm
|
|
156
|
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
|
157
|
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
|
158
|
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
|
159
|
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
|
160
|
+
#.idea/
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
# FlowerTune LLM on $challenge_name Dataset
|
|
2
|
+
|
|
3
|
+
This directory conducts federated instruction tuning with a pretrained [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-v0.3) model on a [$challenge_name dataset](https://huggingface.co/datasets/$dataset_name).
|
|
4
|
+
We use [Flower Datasets](https://flower.dev/docs/datasets/) to download, partition and preprocess the dataset.
|
|
5
|
+
Flower's Simulation Engine is used to simulate the LLM fine-tuning process in a federated way,
|
|
6
|
+
which allows users to perform the training on a single GPU.
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
## Methodology
|
|
10
|
+
|
|
11
|
+
This baseline performs federated LLM fine-tuning with [LoRA](https://arxiv.org/pdf/2106.09685) using the [🤗PEFT](https://huggingface.co/docs/peft/en/index) library.
|
|
12
|
+
The clients' models are aggregated with the FedAvg strategy.
|
|
13
|
+
This provides a baseline performance for the leaderboard of the $challenge_name challenge.
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
## Environments setup
|
|
17
|
+
|
|
18
|
+
Project dependencies are defined in `pyproject.toml`. Install them in an activated Python environment with:
|
|
19
|
+
|
|
20
|
+
```shell
|
|
21
|
+
pip install -e .
|
|
22
|
+
```
|
|
23
|
+
|
|
24
|
+
## Experimental setup
|
|
25
|
+
|
|
26
|
+
The dataset is partitioned into $num_clients IID shards, each of which serves as one client.
|
|
27
|
+
In each round, a fraction ($fraction_fit) of all clients is randomly sampled to participate,
|
|
28
|
+
and the federated fine-tuning lasts for `200` rounds.
|
|
29
|
+
All settings are defined in `$import_name/conf/static_config.yaml`, which must not be modified if you plan to participate in the [LLM leaderboard](https://flower.ai/benchmarks/llm-leaderboard), to ensure fair competition.
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
## Running the challenge
|
|
33
|
+
|
|
34
|
+
First, make sure that your Hugging Face account has access to the [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-v0.3) model. You can request access directly from the Hugging Face website.
|
|
35
|
+
Then, follow the instructions [here](https://huggingface.co/docs/huggingface_hub/en/quick-start#login-command) to log in to your account. Note that you only need to complete this step once on your development machine:
|
|
36
|
+
|
|
37
|
+
```bash
|
|
38
|
+
huggingface-cli login
|
|
39
|
+
```
|
|
40
|
+
|
|
41
|
+
Run the challenge with default config values.
|
|
42
|
+
The configs are in `$import_name/conf/config.yaml` and `$import_name/conf/static_config.yaml`, and are loaded automatically.
|
|
43
|
+
|
|
44
|
+
```bash
|
|
45
|
+
flwr run
|
|
46
|
+
```
|
|
47
|
+
|
|
48
|
+
## VRAM consumption
|
|
49
|
+
|
|
50
|
+
We use Mistral-7B model with 4-bit quantization as default. The estimated VRAM consumption per client for each challenge is shown below:
|
|
51
|
+
|
|
52
|
+
| Challenges | GeneralNLP | Finance | Medical | Code |
|
|
53
|
+
| :--------: | :--------: | :--------: | :--------: | :--------: |
|
|
54
|
+
| VRAM | ~25.50 GB | ~17.30 GB | ~22.80 GB | ~17.40 GB |
|
|
55
|
+
|
|
56
|
+
You can adjust the CPU/GPU resources you assign to each of the clients based on your device, which is specified with `flower.engine.simulation` in `pyproject.toml`.
|
|
@@ -1 +1 @@
|
|
|
1
|
-
"""$project_name."""
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
|
2
|
+
|
|
3
|
+
from flwr.client import ClientApp, NumPyClient
|
|
4
|
+
from flwr.common import Context
|
|
5
|
+
from transformers import AutoModelForSequenceClassification
|
|
6
|
+
|
|
7
|
+
from $import_name.task import (
|
|
8
|
+
get_weights,
|
|
9
|
+
load_data,
|
|
10
|
+
set_weights,
|
|
11
|
+
train,
|
|
12
|
+
test,
|
|
13
|
+
CHECKPOINT,
|
|
14
|
+
DEVICE,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# Flower client
|
|
19
|
+
class FlowerClient(NumPyClient):
|
|
20
|
+
def __init__(self, net, trainloader, testloader, local_epochs):
|
|
21
|
+
self.net = net
|
|
22
|
+
self.trainloader = trainloader
|
|
23
|
+
self.testloader = testloader
|
|
24
|
+
self.local_epochs = local_epochs
|
|
25
|
+
|
|
26
|
+
def get_parameters(self, config):
|
|
27
|
+
return get_weights(self.net)
|
|
28
|
+
|
|
29
|
+
def set_parameters(self, parameters):
|
|
30
|
+
set_weights(self.net, parameters)
|
|
31
|
+
|
|
32
|
+
def fit(self, parameters, config):
|
|
33
|
+
self.set_parameters(parameters)
|
|
34
|
+
train(
|
|
35
|
+
self.net,
|
|
36
|
+
self.trainloader,
|
|
37
|
+
epochs=self.local_epochs,
|
|
38
|
+
)
|
|
39
|
+
return self.get_parameters(config={}), len(self.trainloader), {}
|
|
40
|
+
|
|
41
|
+
def evaluate(self, parameters, config):
|
|
42
|
+
self.set_parameters(parameters)
|
|
43
|
+
loss, accuracy = test(self.net, self.testloader)
|
|
44
|
+
return float(loss), len(self.testloader), {"accuracy": accuracy}
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def client_fn(context: Context):
|
|
48
|
+
# Load model and data
|
|
49
|
+
net = AutoModelForSequenceClassification.from_pretrained(
|
|
50
|
+
CHECKPOINT, num_labels=2
|
|
51
|
+
).to(DEVICE)
|
|
52
|
+
|
|
53
|
+
partition_id = context.node_config["partition-id"]
|
|
54
|
+
num_partitions = context.node_config["num-partitions"]
|
|
55
|
+
trainloader, valloader = load_data(partition_id, num_partitions)
|
|
56
|
+
local_epochs = context.run_config["local-epochs"]
|
|
57
|
+
|
|
58
|
+
# Return Client instance
|
|
59
|
+
return FlowerClient(net, trainloader, valloader, local_epochs).to_client()
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
# Flower ClientApp
|
|
63
|
+
app = ClientApp(
|
|
64
|
+
client_fn,
|
|
65
|
+
)
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
"""$project_name: A Flower / $framework_str app."""
|
|
2
|
+
|
|
3
|
+
import jax
|
|
4
|
+
from flwr.client import NumPyClient, ClientApp
|
|
5
|
+
from flwr.common import Context
|
|
6
|
+
|
|
7
|
+
from $import_name.task import (
|
|
8
|
+
evaluation,
|
|
9
|
+
get_params,
|
|
10
|
+
load_data,
|
|
11
|
+
load_model,
|
|
12
|
+
loss_fn,
|
|
13
|
+
set_params,
|
|
14
|
+
train,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# Define Flower Client and client_fn
|
|
19
|
+
class FlowerClient(NumPyClient):
|
|
20
|
+
def __init__(self):
|
|
21
|
+
self.train_x, self.train_y, self.test_x, self.test_y = load_data()
|
|
22
|
+
self.grad_fn = jax.grad(loss_fn)
|
|
23
|
+
model_shape = self.train_x.shape[1:]
|
|
24
|
+
|
|
25
|
+
self.params = load_model(model_shape)
|
|
26
|
+
|
|
27
|
+
def get_parameters(self, config):
|
|
28
|
+
return get_params(self.params)
|
|
29
|
+
|
|
30
|
+
def set_parameters(self, parameters):
|
|
31
|
+
set_params(self.params, parameters)
|
|
32
|
+
|
|
33
|
+
def fit(self, parameters, config):
|
|
34
|
+
self.set_parameters(parameters)
|
|
35
|
+
self.params, loss, num_examples = train(
|
|
36
|
+
self.params, self.grad_fn, self.train_x, self.train_y
|
|
37
|
+
)
|
|
38
|
+
parameters = self.get_parameters(config={})
|
|
39
|
+
return parameters, num_examples, {"loss": float(loss)}
|
|
40
|
+
|
|
41
|
+
def evaluate(self, parameters, config):
|
|
42
|
+
self.set_parameters(parameters)
|
|
43
|
+
loss, num_examples = evaluation(
|
|
44
|
+
self.params, self.grad_fn, self.test_x, self.test_y
|
|
45
|
+
)
|
|
46
|
+
return float(loss), num_examples, {"loss": float(loss)}
|
|
47
|
+
|
|
48
|
+
def client_fn(context: Context):
|
|
49
|
+
# Return Client instance
|
|
50
|
+
return FlowerClient().to_client()
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
# Flower ClientApp
|
|
54
|
+
app = ClientApp(
|
|
55
|
+
client_fn,
|
|
56
|
+
)
|