flwr-nightly 1.10.0.dev20240619__py3-none-any.whl → 1.10.0.dev20240624__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.
Files changed (98)
  1. flwr/cli/app.py +3 -0
  2. flwr/cli/build.py +3 -7
  3. flwr/cli/new/new.py +104 -28
  4. flwr/cli/new/templates/app/README.flowertune.md.tpl +56 -0
  5. flwr/cli/new/templates/app/code/flwr_tune/__init__.py +15 -0
  6. flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +86 -0
  7. flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +124 -0
  8. flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +34 -0
  9. flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +57 -0
  10. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +59 -0
  11. flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +48 -0
  12. flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +11 -0
  13. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +42 -0
  14. flwr/cli/run/run.py +8 -1
  15. flwr/client/client_app.py +1 -1
  16. flwr/client/dpfedavg_numpy_client.py +1 -1
  17. flwr/client/grpc_rere_client/__init__.py +1 -1
  18. flwr/client/grpc_rere_client/connection.py +1 -1
  19. flwr/client/message_handler/__init__.py +1 -1
  20. flwr/client/message_handler/message_handler.py +1 -1
  21. flwr/client/mod/__init__.py +1 -1
  22. flwr/client/mod/secure_aggregation/__init__.py +1 -1
  23. flwr/client/mod/utils.py +1 -1
  24. flwr/client/rest_client/__init__.py +1 -1
  25. flwr/client/rest_client/connection.py +1 -1
  26. flwr/client/supernode/app.py +1 -1
  27. flwr/common/address.py +1 -1
  28. flwr/common/config.py +8 -6
  29. flwr/common/constant.py +1 -1
  30. flwr/common/date.py +1 -1
  31. flwr/common/dp.py +1 -1
  32. flwr/common/grpc.py +1 -1
  33. flwr/common/secure_aggregation/__init__.py +1 -1
  34. flwr/common/secure_aggregation/crypto/__init__.py +1 -1
  35. flwr/common/secure_aggregation/crypto/shamir.py +1 -1
  36. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -1
  37. flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
  38. flwr/common/secure_aggregation/quantization.py +1 -1
  39. flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
  40. flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
  41. flwr/common/version.py +14 -0
  42. flwr/server/compat/app.py +1 -1
  43. flwr/server/compat/app_utils.py +1 -1
  44. flwr/server/compat/driver_client_proxy.py +1 -1
  45. flwr/server/driver/driver.py +6 -0
  46. flwr/server/driver/grpc_driver.py +85 -63
  47. flwr/server/driver/inmemory_driver.py +28 -26
  48. flwr/server/run_serverapp.py +61 -18
  49. flwr/server/strategy/bulyan.py +1 -1
  50. flwr/server/strategy/dpfedavg_adaptive.py +1 -1
  51. flwr/server/strategy/dpfedavg_fixed.py +1 -1
  52. flwr/server/strategy/fedadagrad.py +1 -1
  53. flwr/server/strategy/fedadam.py +1 -1
  54. flwr/server/strategy/fedavg_android.py +1 -1
  55. flwr/server/strategy/fedavgm.py +1 -1
  56. flwr/server/strategy/fedmedian.py +1 -1
  57. flwr/server/strategy/fedopt.py +1 -1
  58. flwr/server/strategy/fedprox.py +1 -1
  59. flwr/server/strategy/fedxgb_bagging.py +1 -1
  60. flwr/server/strategy/fedxgb_cyclic.py +1 -1
  61. flwr/server/strategy/fedxgb_nn_avg.py +1 -1
  62. flwr/server/strategy/fedyogi.py +1 -1
  63. flwr/server/strategy/krum.py +1 -1
  64. flwr/server/strategy/qfedavg.py +1 -1
  65. flwr/server/superlink/driver/__init__.py +1 -1
  66. flwr/server/superlink/driver/driver_grpc.py +1 -1
  67. flwr/server/superlink/driver/driver_servicer.py +15 -3
  68. flwr/server/superlink/fleet/__init__.py +1 -1
  69. flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
  70. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
  71. flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
  72. flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
  73. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +1 -1
  74. flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
  75. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +1 -1
  76. flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
  77. flwr/server/superlink/fleet/message_handler/message_handler.py +1 -1
  78. flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
  79. flwr/server/superlink/fleet/rest_rere/rest_api.py +1 -1
  80. flwr/server/superlink/fleet/vce/backend/raybackend.py +44 -25
  81. flwr/server/superlink/state/__init__.py +1 -1
  82. flwr/server/superlink/state/in_memory_state.py +1 -1
  83. flwr/server/superlink/state/sqlite_state.py +1 -1
  84. flwr/server/superlink/state/state.py +1 -1
  85. flwr/server/superlink/state/state_factory.py +11 -2
  86. flwr/server/utils/__init__.py +1 -1
  87. flwr/server/utils/tensorboard.py +1 -1
  88. flwr/simulation/__init__.py +1 -1
  89. flwr/simulation/app.py +1 -1
  90. flwr/simulation/ray_transport/__init__.py +1 -1
  91. flwr/simulation/ray_transport/ray_actor.py +0 -6
  92. flwr/simulation/ray_transport/ray_client_proxy.py +1 -1
  93. flwr/simulation/run_simulation.py +47 -28
  94. {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240624.dist-info}/METADATA +2 -1
  95. {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240624.dist-info}/RECORD +98 -88
  96. {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240624.dist-info}/LICENSE +0 -0
  97. {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240624.dist-info}/WHEEL +0 -0
  98. {flwr_nightly-1.10.0.dev20240619.dist-info → flwr_nightly-1.10.0.dev20240624.dist-info}/entry_points.txt +0 -0
flwr/cli/app.py CHANGED
@@ -15,6 +15,7 @@
 """Flower command line interface."""
 
 import typer
+from typer.main import get_command
 
 from .build import build
 from .example import example
@@ -37,5 +38,7 @@ app.command()(run)
 app.command()(build)
 app.command()(install)
 
+typer_click_object = get_command(app)
+
 if __name__ == "__main__":
     app()
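
Note on the `flwr/cli/app.py` change above: `get_command` converts the Typer app into a plain Click command object, which tools that only understand Click (docs generators, test runners) can consume. A minimal, self-contained sketch of the pattern; the `hello` command below is illustrative, not part of Flower:

```python
import typer
from click.testing import CliRunner
from typer.main import get_command

app = typer.Typer()


@app.command()
def hello(name: str = "world") -> None:
    """Say hello."""
    typer.echo(f"Hello {name}")


# Same pattern as the diff: expose the Typer app as a click.Command
typer_click_object = get_command(app)

# Any Click-aware tool can now drive it, e.g. Click's test runner
result = CliRunner().invoke(typer_click_object, ["--name", "flwr"])
assert result.output.strip() == "Hello flwr"
```
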
flwr/cli/build.py CHANGED
@@ -36,13 +36,9 @@ def build(
 ) -> str:
     """Build a Flower project into a Flower App Bundle (FAB).
 
-    You can run `flwr build` without any argument to bundle the current directory:
-
-    `flwr build`
-
-    You can also build a specific directory:
-
-    `flwr build --directory ./projects/flower-hello-world`
+    You can run ``flwr build`` without any arguments to bundle the current directory,
+    or you can use ``--directory`` to build a specific directory:
+    ``flwr build --directory ./projects/flower-hello-world``.
     """
     if directory is None:
         directory = Path.cwd()
flwr/cli/new/new.py CHANGED
@@ -41,6 +41,16 @@ class MlFramework(str, Enum):
     HUGGINGFACE = "HF"
     MLX = "MLX"
     SKLEARN = "sklearn"
+    FLOWERTUNE = "FlowerTune"
+
+
+class LlmChallengeName(str, Enum):
+    """Available LLM challenges."""
+
+    GENERALNLP = "GeneralNLP"
+    FINANCE = "Finance"
+    MEDICAL = "Medical"
+    CODE = "Code"
 
 
 class TemplateNotFound(Exception):
@@ -81,6 +91,7 @@ def render_and_create(file_path: str, template: str, context: Dict[str, str]) ->
     create_file(file_path, content)
 
 
+# pylint: disable=too-many-locals,too-many-branches,too-many-statements
 def new(
     project_name: Annotated[
         Optional[str],
@@ -125,6 +136,19 @@ def new(
 
     framework_str = framework_str.lower()
 
+    if framework_str == "flowertune":
+        llm_challenge_value = prompt_options(
+            "Please select LLM challenge by typing in the number",
+            sorted([challenge.value for challenge in LlmChallengeName]),
+        )
+        selected_value = [
+            name
+            for name, value in vars(LlmChallengeName).items()
+            if value == llm_challenge_value
+        ]
+        llm_challenge_str = selected_value[0]
+        llm_challenge_str = llm_challenge_str.lower()
+
     print(
         typer.style(
             f"\n🔨 Creating Flower project {project_name}...",
@@ -139,33 +163,6 @@ def new(
     import_name = package_name.replace("-", "_")
     project_dir = os.path.join(cwd, package_name)
 
-    # List of files to render
-    files = {
-        ".gitignore": {"template": "app/.gitignore.tpl"},
-        "README.md": {"template": "app/README.md.tpl"},
-        "pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"},
-        f"{import_name}/__init__.py": {"template": "app/code/__init__.py.tpl"},
-        f"{import_name}/server.py": {
-            "template": f"app/code/server.{framework_str}.py.tpl"
-        },
-        f"{import_name}/client.py": {
-            "template": f"app/code/client.{framework_str}.py.tpl"
-        },
-    }
-
-    # Depending on the framework, generate task.py file
-    frameworks_with_tasks = [
-        MlFramework.PYTORCH.value.lower(),
-        MlFramework.JAX.value.lower(),
-        MlFramework.HUGGINGFACE.value.lower(),
-        MlFramework.MLX.value.lower(),
-        MlFramework.TENSORFLOW.value.lower(),
-    ]
-    if framework_str in frameworks_with_tasks:
-        files[f"{import_name}/task.py"] = {
-            "template": f"app/code/task.{framework_str}.py.tpl"
-        }
-
     context = {
         "project_name": project_name,
         "package_name": package_name,
@@ -173,6 +170,85 @@ def new(
         "username": username,
     }
 
+    # List of files to render
+    if framework_str == "flowertune":
+        files = {
+            ".gitignore": {"template": "app/.gitignore.tpl"},
+            "pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"},
+            "README.md": {"template": f"app/README.{framework_str}.md.tpl"},
+            f"{import_name}/__init__.py": {"template": "app/code/__init__.py.tpl"},
+            f"{import_name}/server.py": {
+                "template": "app/code/flwr_tune/server.py.tpl"
+            },
+            f"{import_name}/client.py": {
+                "template": "app/code/flwr_tune/client.py.tpl"
+            },
+            f"{import_name}/app.py": {"template": "app/code/flwr_tune/app.py.tpl"},
+            f"{import_name}/models.py": {
+                "template": "app/code/flwr_tune/models.py.tpl"
+            },
+            f"{import_name}/dataset.py": {
+                "template": "app/code/flwr_tune/dataset.py.tpl"
+            },
+            f"{import_name}/conf/config.yaml": {
+                "template": "app/code/flwr_tune/config.yaml.tpl"
+            },
+            f"{import_name}/conf/static_config.yaml": {
+                "template": "app/code/flwr_tune/static_config.yaml.tpl"
+            },
+        }
+
+        # Challenge specific context
+        fraction_fit = "0.2" if llm_challenge_str == "code" else "0.1"
+        if llm_challenge_str == "generalnlp":
+            challenge_name = "General NLP"
+            num_clients = "20"
+            dataset_name = "vicgalle/alpaca-gpt4"
+        elif llm_challenge_str == "finance":
+            challenge_name = "Finance"
+            num_clients = "50"
+            dataset_name = "FinGPT/fingpt-sentiment-train"
+        elif llm_challenge_str == "medical":
+            challenge_name = "Medical"
+            num_clients = "20"
+            dataset_name = "medalpaca/medical_meadow_medical_flashcards"
+        else:
+            challenge_name = "Code"
+            num_clients = "10"
+            dataset_name = "lucasmccabe-lmi/CodeAlpaca-20k"
+
+        context["llm_challenge_str"] = llm_challenge_str
+        context["fraction_fit"] = fraction_fit
+        context["challenge_name"] = challenge_name
+        context["num_clients"] = num_clients
+        context["dataset_name"] = dataset_name
+    else:
+        files = {
+            ".gitignore": {"template": "app/.gitignore.tpl"},
+            "README.md": {"template": "app/README.md.tpl"},
+            "pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"},
+            f"{import_name}/__init__.py": {"template": "app/code/__init__.py.tpl"},
+            f"{import_name}/server.py": {
+                "template": f"app/code/server.{framework_str}.py.tpl"
+            },
+            f"{import_name}/client.py": {
+                "template": f"app/code/client.{framework_str}.py.tpl"
+            },
+        }
+
+        # Depending on the framework, generate task.py file
+        frameworks_with_tasks = [
+            MlFramework.PYTORCH.value.lower(),
+            MlFramework.JAX.value.lower(),
+            MlFramework.HUGGINGFACE.value.lower(),
+            MlFramework.MLX.value.lower(),
+            MlFramework.TENSORFLOW.value.lower(),
+        ]
+        if framework_str in frameworks_with_tasks:
+            files[f"{import_name}/task.py"] = {
+                "template": f"app/code/task.{framework_str}.py.tpl"
+            }
+
     for file_path, value in files.items():
         render_and_create(
             file_path=os.path.join(project_dir, file_path),
@@ -190,7 +266,7 @@ def new(
     )
     print(
         typer.style(
-            f" cd {project_name}\n" + " pip install -e .\n flwr run\n",
+            f" cd {package_name}\n" + " pip install -e .\n flwr run\n",
            fg=typer.colors.BRIGHT_CYAN,
            bold=True,
        )
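
Note on the challenge selection added to `new.py` above: because `LlmChallengeName` subclasses `str`, each member compares equal to its value string, which is what the `vars(...)` comprehension relies on to map the chosen value back to a member name. A small sketch of the equivalence (the enum-constructor form is an alternative, not what the diff uses):

```python
from enum import Enum


class LlmChallengeName(str, Enum):
    """Available LLM challenges."""

    GENERALNLP = "GeneralNLP"
    FINANCE = "Finance"
    MEDICAL = "Medical"
    CODE = "Code"


# The comprehension in the diff recovers the member name from the value...
selected = [n for n, v in vars(LlmChallengeName).items() if v == "Finance"]
assert selected == ["FINANCE"]

# ...which the enum constructor expresses directly:
assert LlmChallengeName("Finance").name == "FINANCE"
```
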
flwr/cli/new/templates/app/README.flowertune.md.tpl ADDED
@@ -0,0 +1,56 @@
+# FlowerTune LLM on $challenge_name Dataset
+
+This directory conducts federated instruction tuning with a pretrained [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-v0.3) model on the [$challenge_name dataset](https://huggingface.co/datasets/$dataset_name).
+We use [Flower Datasets](https://flower.dev/docs/datasets/) to download, partition, and preprocess the dataset.
+Flower's Simulation Engine is used to simulate the LLM fine-tuning process in a federated way,
+which allows users to perform the training on a single GPU.
+
+
+## Methodology
+
+This baseline performs federated LLM fine-tuning with [LoRA](https://arxiv.org/pdf/2106.09685) using the [🤗PEFT](https://huggingface.co/docs/peft/en/index) library.
+The clients' models are aggregated with the FedAvg strategy.
+This provides a performance baseline for the $challenge_name challenge leaderboard.
+
+
+## Environment setup
+
+Project dependencies are defined in `pyproject.toml`. Install them in an activated Python environment with:
+
+```shell
+pip install -e .
+```
+
+## Experimental setup
+
+The dataset is partitioned in an IID fashion into $num_clients shards, one per client.
+We randomly sample a fraction ($fraction_fit) of the clients to participate in each round,
+and the federated fine-tuning lasts for `200` rounds.
+All settings are defined in `$project_name/conf/static_config.yaml`; to keep the competition fair, they must not be modified if you plan to participate in the [LLM leaderboard](https://flower.ai/benchmarks/llm-leaderboard).
+
+
+## Running the challenge
+
+First, make sure that your Hugging Face account has access to the [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-v0.3) model. You can request access directly on the Hugging Face website.
+Then, follow the instructions [here](https://huggingface.co/docs/huggingface_hub/en/quick-start#login-command) to log in to your account. Note that you only need to complete this step once on your development machine:
+
+```bash
+huggingface-cli login
+```
+
+Run the challenge with the default config values.
+The configs are in `$project_name/conf/config.yaml` and `$project_name/conf/static_config.yaml`, and are loaded automatically.
+
+```bash
+flwr run
+```
+
+## VRAM consumption
+
+We use the Mistral-7B model with 4-bit quantization by default. The estimated VRAM consumption per client for each challenge is shown below:
+
+| Challenges | GeneralNLP | Finance   | Medical   | Code      |
+| :--------: | :--------: | :-------: | :-------: | :-------: |
+| VRAM       | ~25.50 GB  | ~17.30 GB | ~22.80 GB | ~17.40 GB |
+
+You can adjust the CPU/GPU resources assigned to each client based on your device; they are specified under `flower.engine.simulation` in `pyproject.toml`.
flwr/cli/new/templates/app/code/flwr_tune/__init__.py ADDED
@@ -0,0 +1,15 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Flower CLI `new` command app / code / flwr_tune templates."""
flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl ADDED
@@ -0,0 +1,86 @@
+"""$project_name: A Flower / FlowerTune app."""
+
+import os
+import warnings
+from datetime import datetime
+
+from flwr_datasets import FederatedDataset
+from hydra import compose, initialize
+from hydra.utils import instantiate
+
+from flwr.client import ClientApp
+from flwr.common import ndarrays_to_parameters
+from flwr.server import ServerApp, ServerConfig
+
+from $import_name.client import gen_client_fn, get_parameters
+from $import_name.dataset import get_tokenizer_and_data_collator_and_propt_formatting
+from $import_name.models import get_model
+from $import_name.server import fit_weighted_average, get_evaluate_fn, get_on_fit_config
+
+# Avoid warnings
+warnings.filterwarnings("ignore", category=UserWarning)
+os.environ["TOKENIZERS_PARALLELISM"] = "true"
+os.environ["RAY_DISABLE_DOCKER_CPU_WARNING"] = "1"
+
+# Initialise regular config
+with initialize(config_path="conf", version_base="1.1"):
+    cfg = compose(config_name="config")
+
+# Initialise static config
+with initialize(config_path="conf", version_base="1.1"):
+    cfg_static = compose(config_name="static_config")
+
+cfg.train.num_rounds = cfg_static.num_rounds
+
+# Create output directory given current timestamp
+current_time = datetime.now()
+folder_name = current_time.strftime("%Y-%m-%d_%H-%M-%S")
+save_path = os.path.join(os.getcwd(), f"results/{folder_name}")
+os.makedirs(save_path, exist_ok=True)
+
+# Partition dataset and get dataloaders
+partitioner = instantiate(cfg_static.partitioner)
+fds = FederatedDataset(
+    dataset=cfg_static.dataset.name, partitioners={"train": partitioner}
+)
+(
+    tokenizer,
+    data_collator,
+    formatting_prompts_func,
+) = get_tokenizer_and_data_collator_and_propt_formatting(cfg.model.name)
+
+# ClientApp for Flower Next
+client = ClientApp(
+    client_fn=gen_client_fn(
+        fds,
+        tokenizer,
+        formatting_prompts_func,
+        data_collator,
+        cfg.model,
+        cfg.train,
+        save_path,
+    ),
+)
+
+# Get initial model weights
+init_model = get_model(cfg.model)
+init_model_parameters = get_parameters(init_model)
+init_model_parameters = ndarrays_to_parameters(init_model_parameters)
+
+# Instantiate strategy according to config. Here we pass other arguments
+# that are only defined at runtime.
+strategy = instantiate(
+    cfg.strategy,
+    on_fit_config_fn=get_on_fit_config(),
+    fit_metrics_aggregation_fn=fit_weighted_average,
+    initial_parameters=init_model_parameters,
+    evaluate_fn=get_evaluate_fn(
+        cfg.model, cfg.train.save_every_round, cfg_static.num_rounds, save_path
+    ),
+)
+
+# ServerApp for Flower Next
+server = ServerApp(
+    config=ServerConfig(num_rounds=cfg_static.num_rounds),
+    strategy=strategy,
+)
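
Note on `app.py.tpl` above: the template composes its Hydra configs at import time and then copies `num_rounds` over from the static config. A minimal sketch of that compose pattern, assuming a `conf/config.yaml` shaped like the one this template generates:

```python
# Sketch only: requires a "conf/" directory with "config.yaml" relative
# to this file, as generated by `flwr new` for the FlowerTune template.
from hydra import compose, initialize

with initialize(config_path="conf", version_base="1.1"):
    # Values can also be pinned at compose time via overrides, which is
    # an alternative to the template's manual `cfg.train.num_rounds = ...`
    cfg = compose(config_name="config", overrides=["train.num_rounds=200"])

print(cfg.model.name)        # "mistralai/Mistral-7B-v0.3"
print(cfg.train.num_rounds)  # 200
```
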
flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl ADDED
@@ -0,0 +1,124 @@
+"""$project_name: A Flower / FlowerTune app."""
+
+from collections import OrderedDict
+from typing import Callable, Dict, Tuple
+
+import torch
+from omegaconf import DictConfig
+from peft import get_peft_model_state_dict, set_peft_model_state_dict
+from transformers import TrainingArguments
+from trl import SFTTrainer
+
+from flwr.client import NumPyClient
+from flwr.common.typing import NDArrays, Scalar
+from $import_name.dataset import reformat
+from $import_name.models import cosine_annealing, get_model
+
+
+# pylint: disable=too-many-arguments
+# pylint: disable=too-many-instance-attributes
+class FlowerClient(NumPyClient):
+    """Standard Flower client for LLM fine-tuning."""
+
+    def __init__(
+        self,
+        model_cfg: DictConfig,
+        train_cfg: DictConfig,
+        trainset,
+        tokenizer,
+        formatting_prompts_func,
+        data_collator,
+        save_path,
+    ):  # pylint: disable=too-many-arguments
+        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        self.train_cfg = train_cfg
+        self.training_argumnets = TrainingArguments(**train_cfg.training_arguments)
+        self.tokenizer = tokenizer
+        self.formatting_prompts_func = formatting_prompts_func
+        self.data_collator = data_collator
+        self.save_path = save_path
+
+        # instantiate model
+        self.model = get_model(model_cfg)
+
+        self.trainset = trainset
+
+    def fit(
+        self, parameters: NDArrays, config: Dict[str, Scalar]
+    ) -> Tuple[NDArrays, int, Dict]:
+        """Implement distributed fit function for a given client."""
+        set_parameters(self.model, parameters)
+
+        new_lr = cosine_annealing(
+            int(config["current_round"]),
+            self.train_cfg.num_rounds,
+            self.train_cfg.learning_rate_max,
+            self.train_cfg.learning_rate_min,
+        )
+
+        self.training_argumnets.learning_rate = new_lr
+        self.training_argumnets.output_dir = self.save_path
+
+        # Construct trainer
+        trainer = SFTTrainer(
+            model=self.model,
+            tokenizer=self.tokenizer,
+            args=self.training_argumnets,
+            max_seq_length=self.train_cfg.seq_length,
+            train_dataset=self.trainset,
+            formatting_func=self.formatting_prompts_func,
+            data_collator=self.data_collator,
+        )
+
+        # Do local training
+        results = trainer.train()
+
+        return (
+            get_parameters(self.model),
+            len(self.trainset),
+            {"train_loss": results.training_loss},
+        )
+
+
+def set_parameters(model, parameters: NDArrays) -> None:
+    """Change the parameters of the model using the given ones."""
+    peft_state_dict_keys = get_peft_model_state_dict(model).keys()
+    params_dict = zip(peft_state_dict_keys, parameters)
+    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
+    set_peft_model_state_dict(model, state_dict)
+
+
+def get_parameters(model) -> NDArrays:
+    """Return the parameters of the current net."""
+    state_dict = get_peft_model_state_dict(model)
+    return [val.cpu().numpy() for _, val in state_dict.items()]
+
+
+def gen_client_fn(
+    fds,
+    tokenizer,
+    formatting_prompts_func,
+    data_collator,
+    model_cfg: DictConfig,
+    train_cfg: DictConfig,
+    save_path: str,
+) -> Callable[[str], FlowerClient]:  # pylint: disable=too-many-arguments
+    """Generate the client function that creates the Flower Clients."""
+
+    def client_fn(cid: str) -> FlowerClient:
+        """Create a Flower client representing a single organization."""
+        # Let's get the partition corresponding to the i-th client
+        client_trainset = fds.load_partition(int(cid), "train")
+        client_trainset = reformat(client_trainset, llm_task="$llm_challenge_str")
+
+        return FlowerClient(
+            model_cfg,
+            train_cfg,
+            client_trainset,
+            tokenizer,
+            formatting_prompts_func,
+            data_collator,
+            save_path,
+        ).to_client()
+
+    return client_fn
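
Note on `client.py.tpl` above: `get_parameters`/`set_parameters` round-trip only the PEFT (LoRA) adapter weights through NumPy arrays, so each Flower message carries the adapter, not the full 7B model. A hedged sketch of that round-trip; the tiny model name is illustrative only:

```python
from collections import OrderedDict

import torch
from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    set_peft_model_state_dict,
)
from transformers import AutoModelForCausalLM

# Tiny stand-in model so the sketch runs quickly (assumption, not Flower's choice)
base = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
model = get_peft_model(base, LoraConfig(r=8, lora_alpha=16, task_type="CAUSAL_LM"))

# Client -> server: adapter weights become a list of NumPy arrays
ndarrays = [v.cpu().numpy() for v in get_peft_model_state_dict(model).values()]

# Server -> client: NumPy arrays are zipped back onto the adapter keys
keys = get_peft_model_state_dict(model).keys()
set_peft_model_state_dict(
    model, OrderedDict({k: torch.tensor(v) for k, v in zip(keys, ndarrays)})
)
```
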
flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl ADDED
@@ -0,0 +1,34 @@
+# Federated Instruction Tuning
+---
+model:
+  name: "mistralai/Mistral-7B-v0.3"
+  quantization: 4  # 8 or 4 if you want to do quantization with BitsAndBytes
+  gradient_checkpointing: True
+  lora:
+    peft_lora_r: 32
+    peft_lora_alpha: 64
+
+train:
+  num_rounds: null
+  save_every_round: 5
+  learning_rate_max: 5e-5
+  learning_rate_min: 1e-6
+  seq_length: 512
+  training_arguments:
+    output_dir: null  # to be set by hydra
+    learning_rate: null  # to be set by the client
+    per_device_train_batch_size: 16
+    gradient_accumulation_steps: 1
+    logging_steps: 10
+    num_train_epochs: 3
+    max_steps: 10
+    report_to: null
+    save_steps: 1000
+    save_total_limit: 10
+    gradient_checkpointing: True
+    lr_scheduler_type: "constant"
+
+strategy:
+  _target_: flwr.server.strategy.FedAvg
+  fraction_fit: $fraction_fit
+  fraction_evaluate: 0.0  # no client evaluation
flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl ADDED
@@ -0,0 +1,57 @@
+"""$project_name: A Flower / FlowerTune app."""
+
+from transformers import AutoTokenizer
+from trl import DataCollatorForCompletionOnlyLM
+
+
+def formatting_prompts_func(example):
+    """Construct prompts."""
+    output_texts = []
+    # Constructing a standard Alpaca
+    # (https://github.com/tatsu-lab/stanford_alpaca#data-release) prompt
+    mssg = (
+        "Below is an instruction that describes a task. "
+        "Write a response that appropriately completes the request."
+    )
+    for i in range(len(example["instruction"])):
+        text = (
+            f"{mssg}\n### Instruction:\n{example['instruction'][i]}\n"
+            f"### Response: {example['response'][i]}"
+        )
+        output_texts.append(text)
+    return output_texts
+
+
+def get_tokenizer_and_data_collator_and_propt_formatting(model_name: str):
+    """Get tokenizer, data_collator and prompt formatting."""
+    # From: https://huggingface.co/docs/trl/en/sft_trainer
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_name, use_fast=True, padding_side="right"
+    )
+    tokenizer.pad_token = tokenizer.eos_token
+    response_template_with_context = "\n### Response:"  # alpaca response tag
+    response_template_ids = tokenizer.encode(
+        response_template_with_context, add_special_tokens=False
+    )[2:]
+    data_collator = DataCollatorForCompletionOnlyLM(
+        response_template_ids, tokenizer=tokenizer
+    )
+
+    return tokenizer, data_collator, formatting_prompts_func
+
+
+def formatting(dataset):
+    """Format dataset."""
+    dataset["instruction"] = dataset["instruction"] + " " + dataset["input"]
+    return dataset
+
+
+def reformat(dataset, llm_task):
+    """Reformat datasets."""
+    dataset = dataset.rename_column("output", "response")
+    if llm_task == "finance" or llm_task == "code":
+        dataset = dataset.map(formatting, remove_columns=["input"])
+    if llm_task == "medical":
+        dataset = dataset.remove_columns(["instruction"])
+        dataset = dataset.rename_column("input", "instruction")
+    return dataset
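
Note on `dataset.py.tpl` above: `reformat` normalizes each challenge dataset to `instruction`/`response` columns. A toy usage sketch (the data is made up; `reformat` is the function defined in the template):

```python
from datasets import Dataset

toy = Dataset.from_dict(
    {"instruction": ["Compute"], "input": ["2 + 2"], "output": ["4"]}
)

# For "finance"/"code": "output" -> "response", "input" folded into "instruction"
finance = reformat(toy, llm_task="finance")
assert finance[0] == {"instruction": "Compute 2 + 2", "response": "4"}
```
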
flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl ADDED
@@ -0,0 +1,59 @@
+"""$project_name: A Flower / FlowerTune app."""
+
+import math
+
+import torch
+from omegaconf import DictConfig
+from peft import LoraConfig, get_peft_model
+from peft.utils import prepare_model_for_kbit_training
+from transformers import AutoModelForCausalLM, BitsAndBytesConfig
+
+
+def cosine_annealing(
+    current_round: int,
+    total_round: int,
+    lrate_max: float = 0.001,
+    lrate_min: float = 0.0,
+) -> float:
+    """Implement cosine annealing learning rate schedule."""
+    cos_inner = math.pi * current_round / total_round
+    return lrate_min + 0.5 * (lrate_max - lrate_min) * (1 + math.cos(cos_inner))
+
+
+def get_model(model_cfg: DictConfig):
+    """Load model with appropriate quantization config and other optimizations.
+
+    Please refer to this example for `peft + BitsAndBytes`:
+    https://github.com/huggingface/peft/blob/main/examples/fp4_finetuning/finetune_fp4_opt_bnb_peft.py
+    """
+    if model_cfg.quantization == 4:
+        quantization_config = BitsAndBytesConfig(load_in_4bit=True)
+    elif model_cfg.quantization == 8:
+        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+    else:
+        raise ValueError(
+            f"Use 4-bit or 8-bit quantization. You passed: {model_cfg.quantization}/"
+        )
+
+    model = AutoModelForCausalLM.from_pretrained(
+        model_cfg.name,
+        quantization_config=quantization_config,
+        torch_dtype=torch.bfloat16,
+        low_cpu_mem_usage=True,
+    )
+
+    model = prepare_model_for_kbit_training(
+        model, use_gradient_checkpointing=model_cfg.gradient_checkpointing
+    )
+
+    peft_config = LoraConfig(
+        r=model_cfg.lora.peft_lora_r,
+        lora_alpha=model_cfg.lora.peft_lora_alpha,
+        lora_dropout=0.075,
+        task_type="CAUSAL_LM",
+    )
+
+    if model_cfg.gradient_checkpointing:
+        model.config.use_cache = False
+
+    return get_peft_model(model, peft_config)
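
Note on `models.py.tpl` above: `cosine_annealing` decays the learning rate from `lrate_max` to `lrate_min` over the federated rounds. A quick numeric check, restating the function so the snippet is self-contained and plugging in the template's config defaults (`5e-5` max, `1e-6` min, 200 rounds):

```python
import math


def cosine_annealing(current_round, total_round, lrate_max=0.001, lrate_min=0.0):
    cos_inner = math.pi * current_round / total_round
    return lrate_min + 0.5 * (lrate_max - lrate_min) * (1 + math.cos(cos_inner))


print(cosine_annealing(0, 200, 5e-5, 1e-6))    # 5e-05: starts at the max
print(cosine_annealing(100, 200, 5e-5, 1e-6))  # 2.55e-05: midpoint is the mean
print(cosine_annealing(200, 200, 5e-5, 1e-6))  # 1e-06: ends at the min
```
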