flwr-nightly 1.10.0.dev20240619__py3-none-any.whl → 1.10.0.dev20240621__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of flwr-nightly might be problematic.
- flwr/cli/app.py +3 -0
- flwr/cli/build.py +3 -7
- flwr/cli/new/new.py +104 -28
- flwr/cli/new/templates/app/README.flowertune.md.tpl +56 -0
- flwr/cli/new/templates/app/code/flwr_tune/__init__.py +15 -0
- flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +86 -0
- flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +124 -0
- flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +34 -0
- flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +57 -0
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +59 -0
- flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +48 -0
- flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +11 -0
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +42 -0
- flwr/cli/run/run.py +8 -1
- flwr/client/client_app.py +1 -1
- flwr/client/dpfedavg_numpy_client.py +1 -1
- flwr/client/grpc_rere_client/__init__.py +1 -1
- flwr/client/grpc_rere_client/connection.py +1 -1
- flwr/client/message_handler/__init__.py +1 -1
- flwr/client/message_handler/message_handler.py +1 -1
- flwr/client/mod/__init__.py +1 -1
- flwr/client/mod/secure_aggregation/__init__.py +1 -1
- flwr/client/mod/utils.py +1 -1
- flwr/client/rest_client/__init__.py +1 -1
- flwr/client/rest_client/connection.py +1 -1
- flwr/client/supernode/app.py +1 -1
- flwr/common/address.py +1 -1
- flwr/common/config.py +8 -6
- flwr/common/constant.py +1 -1
- flwr/common/date.py +1 -1
- flwr/common/dp.py +1 -1
- flwr/common/grpc.py +1 -1
- flwr/common/secure_aggregation/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/shamir.py +1 -1
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -1
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
- flwr/common/secure_aggregation/quantization.py +1 -1
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
- flwr/common/version.py +14 -0
- flwr/server/compat/app.py +1 -1
- flwr/server/compat/app_utils.py +1 -1
- flwr/server/compat/driver_client_proxy.py +1 -1
- flwr/server/driver/driver.py +6 -0
- flwr/server/driver/grpc_driver.py +85 -63
- flwr/server/driver/inmemory_driver.py +28 -26
- flwr/server/run_serverapp.py +61 -18
- flwr/server/strategy/bulyan.py +1 -1
- flwr/server/strategy/dpfedavg_adaptive.py +1 -1
- flwr/server/strategy/dpfedavg_fixed.py +1 -1
- flwr/server/strategy/fedadagrad.py +1 -1
- flwr/server/strategy/fedadam.py +1 -1
- flwr/server/strategy/fedavg_android.py +1 -1
- flwr/server/strategy/fedavgm.py +1 -1
- flwr/server/strategy/fedmedian.py +1 -1
- flwr/server/strategy/fedopt.py +1 -1
- flwr/server/strategy/fedprox.py +1 -1
- flwr/server/strategy/fedxgb_bagging.py +1 -1
- flwr/server/strategy/fedxgb_cyclic.py +1 -1
- flwr/server/strategy/fedxgb_nn_avg.py +1 -1
- flwr/server/strategy/fedyogi.py +1 -1
- flwr/server/strategy/krum.py +1 -1
- flwr/server/strategy/qfedavg.py +1 -1
- flwr/server/superlink/driver/__init__.py +1 -1
- flwr/server/superlink/driver/driver_grpc.py +1 -1
- flwr/server/superlink/driver/driver_servicer.py +15 -3
- flwr/server/superlink/fleet/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +1 -1
- flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +1 -1
- flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
- flwr/server/superlink/fleet/message_handler/message_handler.py +1 -1
- flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/rest_rere/rest_api.py +1 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +44 -25
- flwr/server/superlink/state/__init__.py +1 -1
- flwr/server/superlink/state/in_memory_state.py +1 -1
- flwr/server/superlink/state/sqlite_state.py +1 -1
- flwr/server/superlink/state/state.py +1 -1
- flwr/server/superlink/state/state_factory.py +11 -2
- flwr/server/utils/__init__.py +1 -1
- flwr/server/utils/tensorboard.py +1 -1
- flwr/simulation/__init__.py +1 -1
- flwr/simulation/app.py +1 -1
- flwr/simulation/ray_transport/__init__.py +1 -1
- flwr/simulation/ray_transport/ray_actor.py +0 -6
- flwr/simulation/ray_transport/ray_client_proxy.py +1 -1
- flwr/simulation/run_simulation.py +47 -28
- {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240621.dist-info}/METADATA +2 -1
- {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240621.dist-info}/RECORD +98 -88
- {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240621.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240621.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240621.dist-info}/entry_points.txt +0 -0
flwr/cli/app.py
CHANGED
```diff
@@ -15,6 +15,7 @@
 """Flower command line interface."""
 
 import typer
+from typer.main import get_command
 
 from .build import build
 from .example import example
@@ -37,5 +38,7 @@ app.command()(run)
 app.command()(build)
 app.command()(install)
 
+typer_click_object = get_command(app)
+
 if __name__ == "__main__":
     app()
```
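The functional change here is the new module-level `typer_click_object`. As a rough illustration (a hypothetical toy app, not code from this package): `typer.main.get_command` converts a Typer app into a plain `click.Command`, the object that Click-based tooling such as docs generators or Click's test utilities expects.

```python
# Illustrative only: a hypothetical toy Typer app showing what get_command returns.
import typer
from typer.main import get_command

app = typer.Typer()


@app.command()
def hello(name: str = "world") -> None:
    """Say hello."""
    typer.echo(f"Hello {name}")


typer_click_object = get_command(app)  # a click.Command wrapping the Typer app

if __name__ == "__main__":
    typer_click_object()  # invoking it behaves like calling app()
```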
flwr/cli/build.py
CHANGED
```diff
@@ -36,13 +36,9 @@ def build(
 ) -> str:
     """Build a Flower project into a Flower App Bundle (FAB).
 
-    You can run
-
-
-
-    You can also build a specific directory:
-
-    `flwr build --directory ./projects/flower-hello-world`
+    You can run ``flwr build`` without any arguments to bundle the current directory,
+    or you can use ``--directory`` to build a specific directory:
+    ``flwr build --directory ./projects/flower-hello-world``.
     """
     if directory is None:
         directory = Path.cwd()
```
flwr/cli/new/new.py
CHANGED
```diff
@@ -41,6 +41,16 @@ class MlFramework(str, Enum):
     HUGGINGFACE = "HF"
     MLX = "MLX"
     SKLEARN = "sklearn"
+    FLOWERTUNE = "FlowerTune"
+
+
+class LlmChallengeName(str, Enum):
+    """Available LLM challenges."""
+
+    GENERALNLP = "GeneralNLP"
+    FINANCE = "Finance"
+    MEDICAL = "Medical"
+    CODE = "Code"
 
 
 class TemplateNotFound(Exception):
@@ -81,6 +91,7 @@ def render_and_create(file_path: str, template: str, context: Dict[str, str]) ->
     create_file(file_path, content)
 
 
+# pylint: disable=too-many-locals,too-many-branches,too-many-statements
 def new(
     project_name: Annotated[
         Optional[str],
@@ -125,6 +136,19 @@ def new(
 
     framework_str = framework_str.lower()
 
+    if framework_str == "flowertune":
+        llm_challenge_value = prompt_options(
+            "Please select LLM challenge by typing in the number",
+            sorted([challenge.value for challenge in LlmChallengeName]),
+        )
+        selected_value = [
+            name
+            for name, value in vars(LlmChallengeName).items()
+            if value == llm_challenge_value
+        ]
+        llm_challenge_str = selected_value[0]
+        llm_challenge_str = llm_challenge_str.lower()
+
     print(
         typer.style(
             f"\n🔨 Creating Flower project {project_name}...",
@@ -139,33 +163,6 @@ def new(
     import_name = package_name.replace("-", "_")
     project_dir = os.path.join(cwd, package_name)
 
-    # List of files to render
-    files = {
-        ".gitignore": {"template": "app/.gitignore.tpl"},
-        "README.md": {"template": "app/README.md.tpl"},
-        "pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"},
-        f"{import_name}/__init__.py": {"template": "app/code/__init__.py.tpl"},
-        f"{import_name}/server.py": {
-            "template": f"app/code/server.{framework_str}.py.tpl"
-        },
-        f"{import_name}/client.py": {
-            "template": f"app/code/client.{framework_str}.py.tpl"
-        },
-    }
-
-    # Depending on the framework, generate task.py file
-    frameworks_with_tasks = [
-        MlFramework.PYTORCH.value.lower(),
-        MlFramework.JAX.value.lower(),
-        MlFramework.HUGGINGFACE.value.lower(),
-        MlFramework.MLX.value.lower(),
-        MlFramework.TENSORFLOW.value.lower(),
-    ]
-    if framework_str in frameworks_with_tasks:
-        files[f"{import_name}/task.py"] = {
-            "template": f"app/code/task.{framework_str}.py.tpl"
-        }
-
     context = {
         "project_name": project_name,
         "package_name": package_name,
@@ -173,6 +170,85 @@ def new(
         "username": username,
     }
 
+    # List of files to render
+    if framework_str == "flowertune":
+        files = {
+            ".gitignore": {"template": "app/.gitignore.tpl"},
+            "pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"},
+            "README.md": {"template": f"app/README.{framework_str}.md.tpl"},
+            f"{import_name}/__init__.py": {"template": "app/code/__init__.py.tpl"},
+            f"{import_name}/server.py": {
+                "template": "app/code/flwr_tune/server.py.tpl"
+            },
+            f"{import_name}/client.py": {
+                "template": "app/code/flwr_tune/client.py.tpl"
+            },
+            f"{import_name}/app.py": {"template": "app/code/flwr_tune/app.py.tpl"},
+            f"{import_name}/models.py": {
+                "template": "app/code/flwr_tune/models.py.tpl"
+            },
+            f"{import_name}/dataset.py": {
+                "template": "app/code/flwr_tune/dataset.py.tpl"
+            },
+            f"{import_name}/conf/config.yaml": {
+                "template": "app/code/flwr_tune/config.yaml.tpl"
+            },
+            f"{import_name}/conf/static_config.yaml": {
+                "template": "app/code/flwr_tune/static_config.yaml.tpl"
+            },
+        }
+
+        # Challenge specific context
+        fraction_fit = "0.2" if llm_challenge_str == "code" else "0.1"
+        if llm_challenge_str == "generalnlp":
+            challenge_name = "General NLP"
+            num_clients = "20"
+            dataset_name = "vicgalle/alpaca-gpt4"
+        elif llm_challenge_str == "finance":
+            challenge_name = "Finance"
+            num_clients = "50"
+            dataset_name = "FinGPT/fingpt-sentiment-train"
+        elif llm_challenge_str == "medical":
+            challenge_name = "Medical"
+            num_clients = "20"
+            dataset_name = "medalpaca/medical_meadow_medical_flashcards"
+        else:
+            challenge_name = "Code"
+            num_clients = "10"
+            dataset_name = "lucasmccabe-lmi/CodeAlpaca-20k"
+
+        context["llm_challenge_str"] = llm_challenge_str
+        context["fraction_fit"] = fraction_fit
+        context["challenge_name"] = challenge_name
+        context["num_clients"] = num_clients
+        context["dataset_name"] = dataset_name
+    else:
+        files = {
+            ".gitignore": {"template": "app/.gitignore.tpl"},
+            "README.md": {"template": "app/README.md.tpl"},
+            "pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"},
+            f"{import_name}/__init__.py": {"template": "app/code/__init__.py.tpl"},
+            f"{import_name}/server.py": {
+                "template": f"app/code/server.{framework_str}.py.tpl"
+            },
+            f"{import_name}/client.py": {
+                "template": f"app/code/client.{framework_str}.py.tpl"
+            },
+        }
+
+        # Depending on the framework, generate task.py file
+        frameworks_with_tasks = [
+            MlFramework.PYTORCH.value.lower(),
+            MlFramework.JAX.value.lower(),
+            MlFramework.HUGGINGFACE.value.lower(),
+            MlFramework.MLX.value.lower(),
+            MlFramework.TENSORFLOW.value.lower(),
+        ]
+        if framework_str in frameworks_with_tasks:
+            files[f"{import_name}/task.py"] = {
+                "template": f"app/code/task.{framework_str}.py.tpl"
+            }
+
     for file_path, value in files.items():
         render_and_create(
             file_path=os.path.join(project_dir, file_path),
@@ -190,7 +266,7 @@ def new(
     )
     print(
         typer.style(
-            f" cd {
+            f" cd {package_name}\n" + " pip install -e .\n flwr run\n",
             fg=typer.colors.BRIGHT_CYAN,
             bold=True,
         )
```
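The `context` dict assembled above feeds the `.tpl` templates listed in this release. The diff does not show how `render_and_create` performs the substitution, but the `$project_name`-style placeholders in the templates match Python's built-in `string.Template` syntax, so a minimal sketch of that kind of rendering (with made-up values, not package defaults) could look like this:

```python
# Sketch under the assumption that the $-placeholders are string.Template-style.
from string import Template

context = {
    "challenge_name": "Finance",
    "num_clients": "50",
    "dataset_name": "FinGPT/fingpt-sentiment-train",
}

tpl = Template("# FlowerTune LLM on $challenge_name Dataset ($num_clients clients)")
print(tpl.substitute(context))
# -> "# FlowerTune LLM on Finance Dataset (50 clients)"
```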
flwr/cli/new/templates/app/README.flowertune.md.tpl
ADDED
````diff
@@ -0,0 +1,56 @@
+# FlowerTune LLM on $challenge_name Dataset
+
+This directory conducts federated instruction tuning with a pretrained [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-v0.3) model on a [$challenge_name dataset](https://huggingface.co/datasets/$dataset_name).
+We use [Flower Datasets](https://flower.dev/docs/datasets/) to download, partition and preprocess the dataset.
+Flower's Simulation Engine is used to simulate the LLM fine-tuning process in a federated way,
+which allows users to perform the training on a single GPU.
+
+
+## Methodology
+
+This baseline performs federated LLM fine-tuning with [LoRA](https://arxiv.org/pdf/2106.09685) using the [🤗PEFT](https://huggingface.co/docs/peft/en/index) library.
+The clients' models are aggregated with the FedAvg strategy.
+This provides a baseline performance for the leaderboard of the $challenge_name challenge.
+
+
+## Environment setup
+
+Project dependencies are defined in `pyproject.toml`. Install them in an activated Python environment with:
+
+```shell
+pip install -e .
+```
+
+## Experimental setup
+
+The dataset is partitioned into $num_clients IID shards, each shard serving as a client.
+We randomly sample a $fraction_fit fraction of the clients to be available for each round,
+and the federated fine-tuning lasts for `200` rounds.
+All settings are defined in `$project_name/conf/static_config.yaml`, which must not be modified (for fair competition) if you plan to participate in the [LLM leaderboard](https://flower.ai/benchmarks/llm-leaderboard).
+
+
+## Running the challenge
+
+First, make sure that you have access to the [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-v0.3) model with your Hugging Face account. You can request access directly from the Hugging Face website.
+Then, follow the instructions [here](https://huggingface.co/docs/huggingface_hub/en/quick-start#login-command) to log in to your account. Note that you only need to complete this step once on your development machine:
+
+```bash
+huggingface-cli login
+```
+
+Run the challenge with the default config values.
+The configs are in `$project_name/conf/config.yaml` and `$project_name/conf/static_config.yaml`, and are loaded automatically.
+
+```bash
+flwr run
+```
+
+## VRAM consumption
+
+We use the Mistral-7B model with 4-bit quantization by default. The estimated VRAM consumption per client for each challenge is shown below:
+
+| Challenges | GeneralNLP | Finance | Medical | Code |
+| :--------: | :--------: | :--------: | :--------: | :--------: |
+| VRAM | ~25.50 GB | ~17.30 GB | ~22.80 GB | ~17.40 GB |
+
+You can adjust the CPU/GPU resources assigned to each client based on your device; the resources are specified under `flower.engine.simulation` in `pyproject.toml`.
````
flwr/cli/new/templates/app/code/flwr_tune/__init__.py
ADDED
```diff
@@ -0,0 +1,15 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Flower CLI `new` command app / code / flwr_tune templates."""
```
flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl
ADDED
```diff
@@ -0,0 +1,86 @@
+"""$project_name: A Flower / FlowerTune app."""
+
+import os
+import warnings
+from datetime import datetime
+
+from flwr_datasets import FederatedDataset
+from hydra import compose, initialize
+from hydra.utils import instantiate
+
+from flwr.client import ClientApp
+from flwr.common import ndarrays_to_parameters
+from flwr.server import ServerApp, ServerConfig
+
+from $import_name.client import gen_client_fn, get_parameters
+from $import_name.dataset import get_tokenizer_and_data_collator_and_propt_formatting
+from $import_name.models import get_model
+from $import_name.server import fit_weighted_average, get_evaluate_fn, get_on_fit_config
+
+# Avoid warnings
+warnings.filterwarnings("ignore", category=UserWarning)
+os.environ["TOKENIZERS_PARALLELISM"] = "true"
+os.environ["RAY_DISABLE_DOCKER_CPU_WARNING"] = "1"
+
+# Initialise regular config
+with initialize(config_path="conf", version_base="1.1"):
+    cfg = compose(config_name="config")
+
+# Initialise static config
+with initialize(config_path="conf", version_base="1.1"):
+    cfg_static = compose(config_name="static_config")
+
+cfg.train.num_rounds = cfg_static.num_rounds
+
+# Create output directory given current timestamp
+current_time = datetime.now()
+folder_name = current_time.strftime("%Y-%m-%d_%H-%M-%S")
+save_path = os.path.join(os.getcwd(), f"results/{folder_name}")
+os.makedirs(save_path, exist_ok=True)
+
+# Partition dataset and get dataloaders
+partitioner = instantiate(cfg_static.partitioner)
+fds = FederatedDataset(
+    dataset=cfg_static.dataset.name, partitioners={"train": partitioner}
+)
+(
+    tokenizer,
+    data_collator,
+    formatting_prompts_func,
+) = get_tokenizer_and_data_collator_and_propt_formatting(cfg.model.name)
+
+# ClientApp for Flower Next
+client = ClientApp(
+    client_fn=gen_client_fn(
+        fds,
+        tokenizer,
+        formatting_prompts_func,
+        data_collator,
+        cfg.model,
+        cfg.train,
+        save_path,
+    ),
+)
+
+# Get initial model weights
+init_model = get_model(cfg.model)
+init_model_parameters = get_parameters(init_model)
+init_model_parameters = ndarrays_to_parameters(init_model_parameters)
+
+# Instantiate strategy according to config. Here we pass other arguments
+# that are only defined at runtime.
+strategy = instantiate(
+    cfg.strategy,
+    on_fit_config_fn=get_on_fit_config(),
+    fit_metrics_aggregation_fn=fit_weighted_average,
+    initial_parameters=init_model_parameters,
+    evaluate_fn=get_evaluate_fn(
+        cfg.model, cfg.train.save_every_round, cfg_static.num_rounds, save_path
+    ),
+)
+
+# ServerApp for Flower Next
+server = ServerApp(
+    config=ServerConfig(num_rounds=cfg_static.num_rounds),
+    strategy=strategy,
+)
```
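Note that this template composes its Hydra configs at module import time rather than through a `@hydra.main` entry point. A self-contained sketch of that compose-API pattern, assuming a `conf/config.yaml` next to the script (none of this is package code):

```python
# Minimal compose-API example: initialize() points Hydra at a config directory
# relative to this file, and compose() returns the merged DictConfig.
from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(config_path="conf", version_base="1.1"):
    cfg = compose(config_name="config")

print(OmegaConf.to_yaml(cfg))  # inspect the resolved config, no @hydra.main needed
```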
flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl
ADDED
```diff
@@ -0,0 +1,124 @@
+"""$project_name: A Flower / FlowerTune app."""
+
+from collections import OrderedDict
+from typing import Callable, Dict, Tuple
+
+import torch
+from omegaconf import DictConfig
+from peft import get_peft_model_state_dict, set_peft_model_state_dict
+from transformers import TrainingArguments
+from trl import SFTTrainer
+
+from flwr.client import NumPyClient
+from flwr.common.typing import NDArrays, Scalar
+from $import_name.dataset import reformat
+from $import_name.models import cosine_annealing, get_model
+
+
+# pylint: disable=too-many-arguments
+# pylint: disable=too-many-instance-attributes
+class FlowerClient(NumPyClient):
+    """Standard Flower client for LLM fine-tuning."""
+
+    def __init__(
+        self,
+        model_cfg: DictConfig,
+        train_cfg: DictConfig,
+        trainset,
+        tokenizer,
+        formatting_prompts_func,
+        data_collator,
+        save_path,
+    ):  # pylint: disable=too-many-arguments
+        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        self.train_cfg = train_cfg
+        self.training_argumnets = TrainingArguments(**train_cfg.training_arguments)
+        self.tokenizer = tokenizer
+        self.formatting_prompts_func = formatting_prompts_func
+        self.data_collator = data_collator
+        self.save_path = save_path
+
+        # instantiate model
+        self.model = get_model(model_cfg)
+
+        self.trainset = trainset
+
+    def fit(
+        self, parameters: NDArrays, config: Dict[str, Scalar]
+    ) -> Tuple[NDArrays, int, Dict]:
+        """Implement distributed fit function for a given client."""
+        set_parameters(self.model, parameters)
+
+        new_lr = cosine_annealing(
+            int(config["current_round"]),
+            self.train_cfg.num_rounds,
+            self.train_cfg.learning_rate_max,
+            self.train_cfg.learning_rate_min,
+        )
+
+        self.training_argumnets.learning_rate = new_lr
+        self.training_argumnets.output_dir = self.save_path
+
+        # Construct trainer
+        trainer = SFTTrainer(
+            model=self.model,
+            tokenizer=self.tokenizer,
+            args=self.training_argumnets,
+            max_seq_length=self.train_cfg.seq_length,
+            train_dataset=self.trainset,
+            formatting_func=self.formatting_prompts_func,
+            data_collator=self.data_collator,
+        )
+
+        # Do local training
+        results = trainer.train()
+
+        return (
+            get_parameters(self.model),
+            len(self.trainset),
+            {"train_loss": results.training_loss},
+        )
+
+
+def set_parameters(model, parameters: NDArrays) -> None:
+    """Change the parameters of the model using the given ones."""
+    peft_state_dict_keys = get_peft_model_state_dict(model).keys()
+    params_dict = zip(peft_state_dict_keys, parameters)
+    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
+    set_peft_model_state_dict(model, state_dict)
+
+
+def get_parameters(model) -> NDArrays:
+    """Return the parameters of the current net."""
+    state_dict = get_peft_model_state_dict(model)
+    return [val.cpu().numpy() for _, val in state_dict.items()]
+
+
+def gen_client_fn(
+    fds,
+    tokenizer,
+    formatting_prompts_func,
+    data_collator,
+    model_cfg: DictConfig,
+    train_cfg: DictConfig,
+    save_path: str,
+) -> Callable[[str], FlowerClient]:  # pylint: disable=too-many-arguments
+    """Generate the client function that creates the Flower Clients."""
+
+    def client_fn(cid: str) -> FlowerClient:
+        """Create a Flower client representing a single organization."""
+        # Let's get the partition corresponding to the i-th client
+        client_trainset = fds.load_partition(int(cid), "train")
+        client_trainset = reformat(client_trainset, llm_task="$llm_challenge_str")
+
+        return FlowerClient(
+            model_cfg,
+            train_cfg,
+            client_trainset,
+            tokenizer,
+            formatting_prompts_func,
+            data_collator,
+            save_path,
+        ).to_client()
+
+    return client_fn
```
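`get_parameters`/`set_parameters` round-trip only the LoRA adapter weights through PEFT's state-dict helpers, which keeps the NumPy payload exchanged with the server small. A minimal sketch of the same round trip on a plain `torch.nn` module, with a vanilla `state_dict` standing in for the PEFT helpers (names here are illustrative, not package code):

```python
# Illustrative round trip: model weights -> list of ndarrays -> model weights.
from collections import OrderedDict

import torch


def get_parameters(model):
    return [v.cpu().numpy() for v in model.state_dict().values()]


def set_parameters(model, parameters):
    keys = model.state_dict().keys()
    state_dict = OrderedDict(
        {k: torch.tensor(v) for k, v in zip(keys, parameters)}
    )
    model.load_state_dict(state_dict)


net = torch.nn.Linear(4, 2)
params = get_parameters(net)  # what a NumPyClient would hand back to Flower
set_parameters(net, params)  # restoring them leaves the weights unchanged
```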
flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl
ADDED
```diff
@@ -0,0 +1,34 @@
+# Federated Instruction Tuning
+---
+model:
+  name: "mistralai/Mistral-7B-v0.3"
+  quantization: 4 # 8 or 4 if you want to do quantization with BitsAndBytes
+  gradient_checkpointing: True
+  lora:
+    peft_lora_r: 32
+    peft_lora_alpha: 64
+
+train:
+  num_rounds: null
+  save_every_round: 5
+  learning_rate_max: 5e-5
+  learning_rate_min: 1e-6
+  seq_length: 512
+  training_arguments:
+    output_dir: null # to be set by hydra
+    learning_rate: null # to be set by the client
+    per_device_train_batch_size: 16
+    gradient_accumulation_steps: 1
+    logging_steps: 10
+    num_train_epochs: 3
+    max_steps: 10
+    report_to: null
+    save_steps: 1000
+    save_total_limit: 10
+    gradient_checkpointing: True
+    lr_scheduler_type: "constant"
+
+strategy:
+  _target_: flwr.server.strategy.FedAvg
+  fraction_fit: $fraction_fit
+  fraction_evaluate: 0.0 # no client evaluation
```
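The `strategy` node relies on Hydra's `_target_` convention: `hydra.utils.instantiate` resolves the dotted path to a class and calls it with the remaining keys as keyword arguments, and extra runtime-only arguments (as passed in `app.py.tpl` above) are merged into the same call. A small standalone sketch:

```python
# instantiate() turns a `_target_`-style config node into a constructed object.
from hydra.utils import instantiate
from omegaconf import OmegaConf

node = OmegaConf.create(
    {"_target_": "flwr.server.strategy.FedAvg", "fraction_fit": 0.1}
)
strategy = instantiate(node, fraction_evaluate=0.0)
print(type(strategy))  # <class 'flwr.server.strategy.fedavg.FedAvg'>
```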
flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl
ADDED
```diff
@@ -0,0 +1,57 @@
+"""$project_name: A Flower / FlowerTune app."""
+
+from transformers import AutoTokenizer
+from trl import DataCollatorForCompletionOnlyLM
+
+
+def formatting_prompts_func(example):
+    """Construct prompts."""
+    output_texts = []
+    # Constructing a standard Alpaca
+    # (https://github.com/tatsu-lab/stanford_alpaca#data-release) prompt
+    mssg = (
+        "Below is an instruction that describes a task. "
+        "Write a response that appropriately completes the request."
+    )
+    for i in range(len(example["instruction"])):
+        text = (
+            f"{mssg}\n### Instruction:\n{example['instruction'][i]}\n"
+            f"### Response: {example['response'][i]}"
+        )
+        output_texts.append(text)
+    return output_texts
+
+
+def get_tokenizer_and_data_collator_and_propt_formatting(model_name: str):
+    """Get tokenizer, data_collator and prompt formatting."""
+    # From: https://huggingface.co/docs/trl/en/sft_trainer
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_name, use_fast=True, padding_side="right"
+    )
+    tokenizer.pad_token = tokenizer.eos_token
+    response_template_with_context = "\n### Response:"  # alpaca response tag
+    response_template_ids = tokenizer.encode(
+        response_template_with_context, add_special_tokens=False
+    )[2:]
+    data_collator = DataCollatorForCompletionOnlyLM(
+        response_template_ids, tokenizer=tokenizer
+    )
+
+    return tokenizer, data_collator, formatting_prompts_func
+
+
+def formatting(dataset):
+    """Format dataset."""
+    dataset["instruction"] = dataset["instruction"] + " " + dataset["input"]
+    return dataset
+
+
+def reformat(dataset, llm_task):
+    """Reformat datasets."""
+    dataset = dataset.rename_column("output", "response")
+    if llm_task == "finance" or llm_task == "code":
+        dataset = dataset.map(formatting, remove_columns=["input"])
+    if llm_task == "medical":
+        dataset = dataset.remove_columns(["instruction"])
+        dataset = dataset.rename_column("input", "instruction")
+    return dataset
```
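`reformat` normalizes the per-challenge column layouts into a common `instruction`/`response` schema. A tiny in-memory illustration of the finance/code branch, using a made-up row (requires the `datasets` package):

```python
# Toy example of the finance/code reformatting path shown above.
from datasets import Dataset

ds = Dataset.from_dict(
    {
        "instruction": ["Classify the sentiment"],
        "input": ["Shares rallied after earnings."],
        "output": ["positive"],
    }
)
ds = ds.rename_column("output", "response")
ds = ds.map(
    lambda row: {"instruction": row["instruction"] + " " + row["input"]},
    remove_columns=["input"],
)
print(ds[0])
# {'instruction': 'Classify the sentiment Shares rallied after earnings.',
#  'response': 'positive'}
```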
flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl
ADDED
```diff
@@ -0,0 +1,59 @@
+"""$project_name: A Flower / FlowerTune app."""
+
+import math
+
+import torch
+from omegaconf import DictConfig
+from peft import LoraConfig, get_peft_model
+from peft.utils import prepare_model_for_kbit_training
+from transformers import AutoModelForCausalLM, BitsAndBytesConfig
+
+
+def cosine_annealing(
+    current_round: int,
+    total_round: int,
+    lrate_max: float = 0.001,
+    lrate_min: float = 0.0,
+) -> float:
+    """Implement cosine annealing learning rate schedule."""
+    cos_inner = math.pi * current_round / total_round
+    return lrate_min + 0.5 * (lrate_max - lrate_min) * (1 + math.cos(cos_inner))
+
+
+def get_model(model_cfg: DictConfig):
+    """Load model with appropriate quantization config and other optimizations.
+
+    Please refer to this example for `peft + BitsAndBytes`:
+    https://github.com/huggingface/peft/blob/main/examples/fp4_finetuning/finetune_fp4_opt_bnb_peft.py
+    """
+    if model_cfg.quantization == 4:
+        quantization_config = BitsAndBytesConfig(load_in_4bit=True)
+    elif model_cfg.quantization == 8:
+        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+    else:
+        raise ValueError(
+            f"Use 4-bit or 8-bit quantization. You passed: {model_cfg.quantization}/"
+        )
+
+    model = AutoModelForCausalLM.from_pretrained(
+        model_cfg.name,
+        quantization_config=quantization_config,
+        torch_dtype=torch.bfloat16,
+        low_cpu_mem_usage=True,
+    )
+
+    model = prepare_model_for_kbit_training(
+        model, use_gradient_checkpointing=model_cfg.gradient_checkpointing
+    )
+
+    peft_config = LoraConfig(
+        r=model_cfg.lora.peft_lora_r,
+        lora_alpha=model_cfg.lora.peft_lora_alpha,
+        lora_dropout=0.075,
+        task_type="CAUSAL_LM",
+    )
+
+    if model_cfg.gradient_checkpointing:
+        model.config.use_cache = False
+
+    return get_peft_model(model, peft_config)
```
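To see what the `cosine_annealing` schedule above produces with the defaults from `config.yaml.tpl` (`learning_rate_max: 5e-5`, `learning_rate_min: 1e-6`) over the 200 rounds mentioned in the README template, a quick check (the function body is copied from the template for this demo):

```python
import math


def cosine_annealing(current_round, total_round, lrate_max=0.001, lrate_min=0.0):
    cos_inner = math.pi * current_round / total_round
    return lrate_min + 0.5 * (lrate_max - lrate_min) * (1 + math.cos(cos_inner))


for rnd in (0, 50, 100, 150, 200):
    print(rnd, f"{cosine_annealing(rnd, 200, 5e-5, 1e-6):.2e}")
# 0 -> 5.00e-05, 50 -> 4.28e-05, 100 -> 2.55e-05, 150 -> 8.18e-06, 200 -> 1.00e-06
```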