flwr-nightly 1.11.0.dev20240826__py3-none-any.whl → 1.11.0.dev20240827__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/new/new.py +6 -10
- flwr/cli/new/templates/app/README.flowertune.md.tpl +16 -6
- flwr/cli/new/templates/app/code/flwr_tune/{client.py.tpl → client_app.py.tpl} +50 -40
- flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +32 -2
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +0 -3
- flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +95 -0
- flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +83 -0
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +34 -7
- flwr/cli/run/run.py +2 -2
- {flwr_nightly-1.11.0.dev20240826.dist-info → flwr_nightly-1.11.0.dev20240827.dist-info}/METADATA +2 -2
- {flwr_nightly-1.11.0.dev20240826.dist-info → flwr_nightly-1.11.0.dev20240827.dist-info}/RECORD +14 -16
- flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +0 -89
- flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +0 -34
- flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +0 -48
- flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +0 -11
- {flwr_nightly-1.11.0.dev20240826.dist-info → flwr_nightly-1.11.0.dev20240827.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.11.0.dev20240826.dist-info → flwr_nightly-1.11.0.dev20240827.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.11.0.dev20240826.dist-info → flwr_nightly-1.11.0.dev20240827.dist-info}/entry_points.txt +0 -0
flwr/cli/new/new.py
CHANGED
@@ -187,24 +187,20 @@ def new(
         "pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"},
         "README.md": {"template": f"app/README.{framework_str}.md.tpl"},
         f"{import_name}/__init__.py": {"template": "app/code/__init__.py.tpl"},
-        f"{import_name}/
-            "template": "app/code/flwr_tune/
+        f"{import_name}/server_app.py": {
+            "template": "app/code/flwr_tune/server_app.py.tpl"
         },
-        f"{import_name}/
-            "template": "app/code/flwr_tune/
+        f"{import_name}/client_app.py": {
+            "template": "app/code/flwr_tune/client_app.py.tpl"
         },
-        f"{import_name}/app.py": {"template": "app/code/flwr_tune/app.py.tpl"},
         f"{import_name}/models.py": {
             "template": "app/code/flwr_tune/models.py.tpl"
         },
         f"{import_name}/dataset.py": {
             "template": "app/code/flwr_tune/dataset.py.tpl"
         },
-        f"{import_name}/
-            "template": "app/code/flwr_tune/
-        },
-        f"{import_name}/conf/static_config.yaml": {
-            "template": "app/code/flwr_tune/static_config.yaml.tpl"
+        f"{import_name}/strategy.py": {
+            "template": "app/code/flwr_tune/strategy.py.tpl"
         },
     }
 
flwr/cli/new/templates/app/README.flowertune.md.tpl
CHANGED
@@ -23,10 +23,12 @@ pip install -e .
 
 ## Experimental setup
 
-The dataset is
-We randomly sample $fraction_fit
-
-
+The dataset is divided into $num_clients partitions in an IID fashion, a partition is assigned to each ClientApp.
+We randomly sample a fraction ($fraction_fit) of the total nodes to participate in each round, for a total of `200` rounds.
+All settings are defined in `pyproject.toml`.
+
+> [!IMPORTANT]
+> Please note that `[tool.flwr.app.config.static]` and `options.num-supernodes` under `[tool.flwr.federations.local-simulation]` are not allowed to be modified for fair competition if you plan to participated in the [LLM leaderboard](https://flower.ai/benchmarks/llm-leaderboard).
 
 
 ## Running the challenge
@@ -39,7 +41,7 @@ huggingface-cli login
 ```
 
 Run the challenge with default config values.
-The configs are in
+The configs are defined in `[tool.flwr.app.config]` entry of `pyproject.toml`, and are loaded automatically.
 
 ```bash
 flwr run
@@ -53,4 +55,12 @@ We use Mistral-7B model with 4-bit quantization as default. The estimated VRAM c
 | :--------: | :--------: | :--------: | :--------: | :--------: |
 | VRAM | ~25.50 GB | ~17.30 GB | ~22.80 GB | ~17.40 GB |
 
-You can adjust the CPU/GPU resources you assign to each of the clients based on your device, which
+You can adjust the CPU/GPU resources you assign to each of the clients based on your device, which are specified with `options.backend.clientapp-cpus` and `options.backend.clientapp-gpus` under `[tool.flwr.federations.local-simulation]` entry in `pyproject.toml`.
+
+
+## Model saving
+
+The global PEFT model checkpoints are saved every 5 rounds after aggregation on the sever side as default, which can be specified with `train.save-every-round` under [tool.flwr.app.config] entry in `pyproject.toml`.
+
+> [!NOTE]
+> Please provide the last PEFT checkpoint if you plan to participated in the [LLM leaderboard](https://flower.ai/benchmarks/llm-leaderboard).
flwr/cli/new/templates/app/code/flwr_tune/{client.py.tpl → client_app.py.tpl}
RENAMED
@@ -1,20 +1,32 @@
 """$project_name: A Flower / FlowerTune app."""
 
+import os
+import warnings
 from collections import OrderedDict
-from typing import
+from typing import Dict, Tuple
 
 import torch
+from flwr.client import ClientApp, NumPyClient
+from flwr.common import Context
+from flwr.common.config import unflatten_dict
+from flwr.common.typing import NDArrays, Scalar
 from omegaconf import DictConfig
 from peft import get_peft_model_state_dict, set_peft_model_state_dict
 from transformers import TrainingArguments
 from trl import SFTTrainer
 
-from
-
-
-
+from $import_name.dataset import (
+    get_tokenizer_and_data_collator_and_propt_formatting,
+    load_data,
+    replace_keys,
+)
 from $import_name.models import cosine_annealing, get_model
 
+# Avoid warnings
+os.environ["TOKENIZERS_PARALLELISM"] = "true"
+os.environ["RAY_DISABLE_DOCKER_CPU_WARNING"] = "1"
+warnings.filterwarnings("ignore", category=UserWarning)
+
 
 # pylint: disable=too-many-arguments
 # pylint: disable=too-many-instance-attributes
@@ -29,7 +41,7 @@ class FlowerClient(NumPyClient):
         tokenizer,
         formatting_prompts_func,
         data_collator,
-
+        num_rounds,
     ):  # pylint: disable=too-many-arguments
         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         self.train_cfg = train_cfg
@@ -37,13 +49,12 @@ class FlowerClient(NumPyClient):
         self.tokenizer = tokenizer
         self.formatting_prompts_func = formatting_prompts_func
         self.data_collator = data_collator
-        self.
+        self.num_rounds = num_rounds
+        self.trainset = trainset
 
         # instantiate model
         self.model = get_model(model_cfg)
 
-        self.trainset = trainset
-
     def fit(
         self, parameters: NDArrays, config: Dict[str, Scalar]
     ) -> Tuple[NDArrays, int, Dict]:
@@ -52,13 +63,13 @@ class FlowerClient(NumPyClient):
 
         new_lr = cosine_annealing(
             int(config["current_round"]),
-            self.
+            self.num_rounds,
             self.train_cfg.learning_rate_max,
             self.train_cfg.learning_rate_min,
         )
 
         self.training_argumnets.learning_rate = new_lr
-        self.training_argumnets.output_dir =
+        self.training_argumnets.output_dir = config["save_path"]
 
         # Construct trainer
         trainer = SFTTrainer(
@@ -95,32 +106,31 @@ def get_parameters(model) -> NDArrays:
     return [val.cpu().numpy() for _, val in state_dict.items()]
 
 
-def
-    ...
-    return client_fn
+def client_fn(context: Context) -> FlowerClient:
+    """Create a Flower client representing a single organization."""
+    partition_id = context.node_config["partition-id"]
+    num_partitions = context.node_config["num-partitions"]
+    num_rounds = context.run_config["num-server-rounds"]
+    cfg = DictConfig(replace_keys(unflatten_dict(context.run_config)))
+
+    # Let's get the client partition
+    client_trainset = load_data(partition_id, num_partitions, cfg.static.dataset.name)
+    (
+        tokenizer,
+        data_collator,
+        formatting_prompts_func,
+    ) = get_tokenizer_and_data_collator_and_propt_formatting(cfg.model.name)
+
+    return FlowerClient(
+        cfg.model,
+        cfg.train,
+        client_trainset,
+        tokenizer,
+        formatting_prompts_func,
+        data_collator,
+        num_rounds,
+    ).to_client()
+
+
+# Flower ClientApp
+app = ClientApp(client_fn)
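In the new `client_app.py.tpl`, every `fit()` call recomputes the learning rate from the current round via `cosine_annealing(...)` before handing it to the `SFTTrainer`. That helper lives in `models.py.tpl` and is not part of this diff, so the snippet below is only an illustrative sketch of the usual cosine schedule such a function computes, using the template's default rate bounds and round count; the actual implementation in the generated project may differ.

```python
import math


def cosine_annealing(current_round, total_rounds, lr_max=5e-5, lr_min=1e-6):
    """Illustrative cosine schedule: decay from lr_max to lr_min over total_rounds."""
    cos_inner = math.pi * current_round / total_rounds
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(cos_inner))


# With the template defaults (train.learning-rate-max = 5e-5,
# train.learning-rate-min = 1e-6, num-server-rounds = 200):
print(cosine_annealing(1, 200))    # ~5.0e-05 (start of training)
print(cosine_annealing(100, 200))  # ~2.55e-05 (halfway)
print(cosine_annealing(200, 200))  # 1e-06 (final round)
```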
flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl
CHANGED
@@ -1,8 +1,12 @@
 """$project_name: A Flower / FlowerTune app."""
 
+from flwr_datasets import FederatedDataset
+from flwr_datasets.partitioner import IidPartitioner
 from transformers import AutoTokenizer
 from trl import DataCollatorForCompletionOnlyLM
 
+FDS = None  # Cache FederatedDataset
+
 
 def formatting_prompts_func(example):
     """Construct prompts."""
@@ -24,7 +28,6 @@ def formatting_prompts_func(example):
 
 def get_tokenizer_and_data_collator_and_propt_formatting(model_name: str):
     """Get tokenizer, data_collator and prompt formatting."""
-    # From: https://huggingface.co/docs/trl/en/sft_trainer
     tokenizer = AutoTokenizer.from_pretrained(
         model_name, use_fast=True, padding_side="right"
     )
@@ -49,9 +52,36 @@ def formatting(dataset):
 def reformat(dataset, llm_task):
     """Reformat datasets."""
     dataset = dataset.rename_column("output", "response")
-    if llm_task
+    if llm_task in ["finance", "code"]:
         dataset = dataset.map(formatting, remove_columns=["input"])
     if llm_task == "medical":
         dataset = dataset.remove_columns(["instruction"])
         dataset = dataset.rename_column("input", "instruction")
     return dataset
+
+
+def load_data(partition_id: int, num_partitions: int, dataset_name: str):
+    """Load partition data."""
+    # Only initialize `FederatedDataset` once
+    global FDS
+    if FDS is None:
+        partitioner = IidPartitioner(num_partitions=num_partitions)
+        FDS = FederatedDataset(
+            dataset=dataset_name,
+            partitioners={"train": partitioner},
+        )
+    client_trainset = FDS.load_partition(partition_id, "train")
+    client_trainset = reformat(client_trainset, llm_task="generalnlp")
+    return client_trainset
+
+
+def replace_keys(input_dict, match="-", target="_"):
+    """Recursively replace match string with target string in dictionary keys."""
+    new_dict = {}
+    for key, value in input_dict.items():
+        new_key = key.replace(match, target)
+        if isinstance(value, dict):
+            new_dict[new_key] = replace_keys(value, match, target)
+        else:
+            new_dict[new_key] = value
+    return new_dict
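Both the new `client_app.py.tpl` and `server_app.py.tpl` build their nested config with `DictConfig(replace_keys(unflatten_dict(context.run_config)))`. The sketch below shows that pipeline on a hand-written flat config; it assumes `unflatten_dict` (imported from `flwr.common.config` in this diff) nests keys on `.`, as its use here suggests, while `replace_keys` is exactly the helper added above.

```python
from flwr.common.config import unflatten_dict
from omegaconf import DictConfig


def replace_keys(input_dict, match="-", target="_"):
    """Same helper as in dataset.py.tpl: dash -> underscore in keys, recursively."""
    new_dict = {}
    for key, value in input_dict.items():
        new_key = key.replace(match, target)
        new_dict[new_key] = (
            replace_keys(value, match, target) if isinstance(value, dict) else value
        )
    return new_dict


# A flat run config, shaped like the [tool.flwr.app.config] entries in pyproject.toml
run_config = {
    "model.name": "mistralai/Mistral-7B-v0.3",
    "model.lora.peft-lora-r": 32,
    "train.learning-rate-max": 5e-5,
    "num-server-rounds": 200,
}

cfg = DictConfig(replace_keys(unflatten_dict(run_config)))
print(cfg.model.lora.peft_lora_r)   # 32
print(cfg.train.learning_rate_max)  # 5e-05
```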
flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl
CHANGED
@@ -22,9 +22,6 @@ def cosine_annealing(
 
 def get_model(model_cfg: DictConfig):
     """Load model with appropriate quantization config and other optimizations.
-
-    Please refer to this example for `peft + BitsAndBytes`:
-    https://github.com/huggingface/peft/blob/main/examples/fp4_finetuning/finetune_fp4_opt_bnb_peft.py
     """
     if model_cfg.quantization == 4:
         quantization_config = BitsAndBytesConfig(load_in_4bit=True)
flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl
ADDED
@@ -0,0 +1,95 @@
+"""$project_name: A Flower / FlowerTune app."""
+
+import os
+from datetime import datetime
+
+from flwr.common import Context, ndarrays_to_parameters
+from flwr.common.config import unflatten_dict
+from flwr.server import ServerApp, ServerAppComponents, ServerConfig
+from omegaconf import DictConfig
+
+from $import_name.client_app import get_parameters, set_parameters
+from $import_name.models import get_model
+from $import_name.dataset import replace_keys
+from $import_name.strategy import FlowerTuneLlm
+
+
+# Get function that will be executed by the strategy's evaluate() method
+# Here we use it to save global model checkpoints
+def get_evaluate_fn(model_cfg, save_every_round, total_round, save_path):
+    """Return an evaluation function for saving global model."""
+
+    def evaluate(server_round: int, parameters, config):
+        # Save model
+        if server_round != 0 and (
+            server_round == total_round or server_round % save_every_round == 0
+        ):
+            # Init model
+            model = get_model(model_cfg)
+            set_parameters(model, parameters)
+
+            model.save_pretrained(f"{save_path}/peft_{server_round}")
+
+        return 0.0, {}
+
+    return evaluate
+
+
+def get_on_fit_config(save_path):
+    """Return a function that will be used to construct the config that the
+    client's fit() method will receive."""
+
+    def fit_config_fn(server_round: int):
+        fit_config = {}
+        fit_config["current_round"] = server_round
+        fit_config["save_path"] = save_path
+        return fit_config
+
+    return fit_config_fn
+
+
+def fit_weighted_average(metrics):
+    """Aggregate (federated) evaluation metrics."""
+    # Multiply accuracy of each client by number of examples used
+    losses = [num_examples * m["train_loss"] for num_examples, m in metrics]
+    examples = [num_examples for num_examples, _ in metrics]
+
+    # Aggregate and return custom metric (weighted average)
+    return {"train_loss": sum(losses) / sum(examples)}
+
+
+def server_fn(context: Context):
+    """Construct components that set the ServerApp behaviour."""
+    # Create output directory given current timestamp
+    current_time = datetime.now()
+    folder_name = current_time.strftime("%Y-%m-%d_%H-%M-%S")
+    save_path = os.path.join(os.getcwd(), f"results/{folder_name}")
+    os.makedirs(save_path, exist_ok=True)
+
+    # Read from config
+    num_rounds = context.run_config["num-server-rounds"]
+    cfg = DictConfig(replace_keys(unflatten_dict(context.run_config)))
+
+    # Get initial model weights
+    init_model = get_model(cfg.model)
+    init_model_parameters = get_parameters(init_model)
+    init_model_parameters = ndarrays_to_parameters(init_model_parameters)
+
+    # Define strategy
+    strategy = FlowerTuneLlm(
+        fraction_fit=cfg.strategy.fraction_fit,
+        fraction_evaluate=cfg.strategy.fraction_evaluate,
+        on_fit_config_fn=get_on_fit_config(save_path),
+        fit_metrics_aggregation_fn=fit_weighted_average,
+        initial_parameters=init_model_parameters,
+        evaluate_fn=get_evaluate_fn(
+            cfg.model, cfg.train.save_every_round, num_rounds, save_path
+        ),
+    )
+    config = ServerConfig(num_rounds=num_rounds)
+
+    return ServerAppComponents(strategy=strategy, config=config)
+
+
+# Flower ServerApp
+app = ServerApp(server_fn=server_fn)
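The `evaluate` closure in the new `server_app.py.tpl` is what persists the global PEFT adapter: it skips round 0, saves on every round divisible by `save-every-round`, and always saves on the final round, writing to `results/<timestamp>/peft_<round>`. A minimal sketch of just that save condition, using the template defaults:

```python
def should_save(server_round: int, total_round: int = 200, save_every_round: int = 5) -> bool:
    # Mirrors the condition inside get_evaluate_fn's evaluate(): never at round 0,
    # always at the last round, otherwise every `save_every_round` rounds.
    return server_round != 0 and (
        server_round == total_round or server_round % save_every_round == 0
    )


# With a shorter 20-round run for illustration, checkpoints land at rounds:
print([r for r in range(0, 21) if should_save(r, total_round=20)])  # [5, 10, 15, 20]
```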
flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl
ADDED
@@ -0,0 +1,83 @@
+"""$project_name: A Flower / FlowerTune app."""
+
+from io import BytesIO
+from logging import INFO, WARN
+from typing import List, Tuple, Union
+
+from flwr.common import FitIns, FitRes, Parameters, log, parameters_to_ndarrays
+from flwr.server.client_manager import ClientManager
+from flwr.server.client_proxy import ClientProxy
+from flwr.server.strategy import FedAvg
+
+
+class FlowerTuneLlm(FedAvg):
+    """Customised FedAvg strategy implementation.
+
+    This class behaves just like FedAvg but also tracks the communication
+    costs associated with `fit` over FL rounds.
+    """
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.comm_tracker = CommunicationTracker()
+
+    def configure_fit(
+        self, server_round: int, parameters: Parameters, client_manager: ClientManager
+    ):
+        """Configure the next round of training."""
+        return_clients = super().configure_fit(server_round, parameters, client_manager)
+
+        # Test communication costs
+        fit_ins_list = [fit_ins for _, fit_ins in return_clients]
+        self.comm_tracker.track(fit_ins_list)
+
+        return return_clients
+
+    def aggregate_fit(
+        self,
+        server_round: int,
+        results: List[Tuple[ClientProxy, FitRes]],
+        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
+    ):
+        """Aggregate fit results using weighted average."""
+        # Test communication costs
+        fit_res_list = [fit_res for _, fit_res in results]
+        self.comm_tracker.track(fit_res_list)
+
+        parameters_aggregated, metrics_aggregated = super().aggregate_fit(
+            server_round, results, failures
+        )
+
+        return parameters_aggregated, metrics_aggregated
+
+
+class CommunicationTracker:
+    """Communication costs tracker over FL rounds."""
+    def __init__(self):
+        self.curr_comm_cost = 0.0
+
+    @staticmethod
+    def _compute_bytes(parameters):
+        return sum([BytesIO(t).getbuffer().nbytes for t in parameters.tensors])
+
+    def track(self, fit_list: List[Union[FitIns, FitRes]]):
+        size_bytes_list = [
+            self._compute_bytes(fit_ele.parameters)
+            for fit_ele in fit_list
+        ]
+        comm_cost = sum(size_bytes_list) / 1024**2
+
+        self.curr_comm_cost += comm_cost
+        log(
+            INFO,
+            "Communication budget: used %.2f MB (+%.2f MB this round) / 200,000 MB",
+            self.curr_comm_cost,
+            comm_cost,
+        )
+
+        if self.curr_comm_cost > 2e5:
+            log(
+                WARN,
+                "The accumulated communication cost has exceeded 200,000 MB. "
+                "Please consider reducing it if you plan to participate "
+                "FlowerTune LLM Leaderboard.",
+            )
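`CommunicationTracker._compute_bytes` simply sums the serialized tensor sizes of a `Parameters` object, and `track()` converts the total to MB against the 200,000 MB budget; both the outgoing `FitIns` (from `configure_fit`) and the returning `FitRes` (in `aggregate_fit`) are counted. A minimal sketch of the same accounting on a toy payload, reusing `ndarrays_to_parameters` from `flwr.common` (already imported by the server app above); the per-round figure is only a rough estimate under the assumption that client updates are about the same size as the global parameters:

```python
from io import BytesIO

import numpy as np
from flwr.common import ndarrays_to_parameters

# A toy "model update": one 100x1000 float32 array, roughly 0.4 MB serialized
payload = ndarrays_to_parameters([np.zeros((100, 1000), dtype=np.float32)])


def compute_bytes(parameters) -> int:
    # Same computation as CommunicationTracker._compute_bytes
    return sum(BytesIO(t).getbuffer().nbytes for t in parameters.tensors)


size_mb = compute_bytes(payload) / 1024**2
# configure_fit and aggregate_fit are both tracked, so a round with 10 sampled
# clients costs roughly 2 * 10 * size_mb against the 200,000 MB budget.
print(f"{size_mb:.2f} MB per copy, ~{2 * 10 * size_mb:.2f} MB per round with 10 clients")
```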
flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl
CHANGED
@@ -8,15 +8,15 @@ version = "1.0.0"
 description = ""
 license = "Apache-2.0"
 dependencies = [
-    "flwr[simulation]>=1.
-    "flwr-datasets>=0.
-    "hydra-core==1.3.2",
+    "flwr[simulation]>=1.10.0",
+    "flwr-datasets>=0.3.0",
     "trl==0.8.1",
     "bitsandbytes==0.43.0",
     "scipy==1.13.0",
     "peft==0.6.2",
     "transformers==4.39.3",
     "sentencepiece==0.2.0",
+    "omegaconf==2.3.0",
 ]
 
 [tool.hatch.build.targets.wheel]
@@ -26,14 +26,41 @@ packages = ["."]
 publisher = "$username"
 
 [tool.flwr.app.components]
-serverapp = "$import_name.app
-clientapp = "$import_name.app
+serverapp = "$import_name.server_app:app"
+clientapp = "$import_name.client_app:app"
 
 [tool.flwr.app.config]
-
+model.name = "mistralai/Mistral-7B-v0.3"
+model.quantization = 4
+model.gradient-checkpointing = true
+model.lora.peft-lora-r = 32
+model.lora.peft-lora-alpha = 64
+train.save-every-round = 5
+train.learning-rate-max = 5e-5
+train.learning-rate-min = 1e-6
+train.seq-length = 512
+train.training-arguments.output-dir = ""
+train.training-arguments.learning-rate = ""
+train.training-arguments.per-device-train-batch-size = 16
+train.training-arguments.gradient-accumulation-steps = 1
+train.training-arguments.logging-steps = 10
+train.training-arguments.num-train-epochs = 3
+train.training-arguments.max-steps = 10
+train.training-arguments.save-steps = 1000
+train.training-arguments.save-total-limit = 10
+train.training-arguments.gradient-checkpointing = true
+train.training-arguments.lr-scheduler-type = "constant"
+strategy.fraction-fit = $fraction_fit
+strategy.fraction-evaluate = 0.0
+num-server-rounds = 200
+
+[tool.flwr.app.config.static]
+dataset.name = "$dataset_name"
 
 [tool.flwr.federations]
 default = "local-simulation"
 
 [tool.flwr.federations.local-simulation]
-options.num-supernodes =
+options.num-supernodes = $num_clients
+options.backend.client-resources.num-cpus = 6
+options.backend.client-resources.num-gpus = 1.0
flwr/cli/run/run.py
CHANGED
@@ -124,14 +124,14 @@ def run(
 
 
 def _run_with_superexec(
-    app:
+    app: Path,
     federation_config: Dict[str, Any],
     config_overrides: Optional[List[str]],
 ) -> None:
 
     insecure_str = federation_config.get("insecure")
     if root_certificates := federation_config.get("root-certificates"):
-        root_certificates_bytes =
+        root_certificates_bytes = (app / root_certificates).read_bytes()
     if insecure := bool(insecure_str):
         typer.secho(
             "❌ `root_certificates` were provided but the `insecure` parameter"
{flwr_nightly-1.11.0.dev20240826.dist-info → flwr_nightly-1.11.0.dev20240827.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: flwr-nightly
-Version: 1.11.0.
+Version: 1.11.0.dev20240827
 Summary: Flower: A Friendly Federated Learning Framework
 Home-page: https://flower.ai
 License: Apache-2.0
@@ -195,7 +195,7 @@ Other [examples](https://github.com/adap/flower/tree/main/examples):
 - [PyTorch: From Centralized to Federated](https://github.com/adap/flower/tree/main/examples/pytorch-from-centralized-to-federated)
 - [Vertical FL](https://github.com/adap/flower/tree/main/examples/vertical-fl)
 - [Federated Finetuning of OpenAI's Whisper](https://github.com/adap/flower/tree/main/examples/whisper-federated-finetuning)
-- [Federated Finetuning of Large Language Model](https://github.com/adap/flower/tree/main/examples/llm
+- [Federated Finetuning of Large Language Model](https://github.com/adap/flower/tree/main/examples/flowertune-llm)
 - [Federated Finetuning of a Vision Transformer](https://github.com/adap/flower/tree/main/examples/flowertune-vit)
 - [Advanced Flower with TensorFlow/Keras](https://github.com/adap/flower/tree/main/examples/advanced-tensorflow)
 - [Advanced Flower with PyTorch](https://github.com/adap/flower/tree/main/examples/advanced-pytorch)
{flwr_nightly-1.11.0.dev20240826.dist-info → flwr_nightly-1.11.0.dev20240827.dist-info}/RECORD
RENAMED
@@ -6,10 +6,10 @@ flwr/cli/config_utils.py,sha256=mDGXbcIxG14UpkUplILBYUkSk5M1LeTzZYDGNx-pFpU,7540
 flwr/cli/example.py,sha256=1bGDYll3BXQY2kRqSN-oICqS5n1b9m0g0RvXTopXHl4,2215
 flwr/cli/install.py,sha256=tUncrbZYRbC9QEcWSeTER16plPEoU-ERP0-nMgWiSPo,7094
 flwr/cli/new/__init__.py,sha256=cQzK1WH4JP2awef1t2UQ2xjl1agVEz9rwutV18SWV1k,789
-flwr/cli/new/new.py,sha256=
+flwr/cli/new/new.py,sha256=Q4iNccbodwfZY2th-Ws83txm-VolI236dscOziDVX2c,9396
 flwr/cli/new/templates/__init__.py,sha256=4luU8RL-CK8JJCstQ_ON809W9bNTkY1l9zSaPKBkgwY,725
 flwr/cli/new/templates/app/.gitignore.tpl,sha256=XixnHdyeMB2vwkGtGnwHqoWpH-9WChdyG0GXe57duhc,3078
-flwr/cli/new/templates/app/README.flowertune.md.tpl,sha256=
+flwr/cli/new/templates/app/README.flowertune.md.tpl,sha256=lxr_RCGfiDy8QGcMVdjXsUXWM_gLf6cY7UQanGL_FFQ,3304
 flwr/cli/new/templates/app/README.md.tpl,sha256=t7w4YFZEcJOxAnuJmNPw5-fDdIJu7PfLd8gFJDiBwwo,436
 flwr/cli/new/templates/app/__init__.py,sha256=DU7QMY7IhMQyuwm_tja66xU0KXTWQFqzfTqwg-_NJdE,729
 flwr/cli/new/templates/app/code/__init__.py,sha256=EM6vfvgAILKPaPn7H1wMV1Wi01WyZCP_Eg6NxD6oWg8,736
@@ -22,13 +22,11 @@ flwr/cli/new/templates/app/code/client.pytorch.py.tpl,sha256=WczaR5avJUhfw2Grn2K
 flwr/cli/new/templates/app/code/client.sklearn.py.tpl,sha256=xW9cuKhybk5S8IeDZhbeb0DNegDIJGEYrzMKsxgc2GE,2978
 flwr/cli/new/templates/app/code/client.tensorflow.py.tpl,sha256=u3KKf7hC9xGqOIUJXYCHJ_jiIu3aVbsC8pxVxm4yN6I,1759
 flwr/cli/new/templates/app/code/flwr_tune/__init__.py,sha256=JgNgBtKdm1jKM9625WxappCAVUGtYAmcjKSsXJ1u3ZQ,748
-flwr/cli/new/templates/app/code/flwr_tune/
-flwr/cli/new/templates/app/code/flwr_tune/
-flwr/cli/new/templates/app/code/flwr_tune/
-flwr/cli/new/templates/app/code/flwr_tune/
-flwr/cli/new/templates/app/code/flwr_tune/
-flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl,sha256=tz4hgAGV6pn5cOFW10ELRNRsUzLBSssHxrH_gSs_jtk,1552
-flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl,sha256=cBPpBVN_N7p4T2a3rqChlngmE0dB_jveOLHesNcEHvs,268
+flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl,sha256=DbotzaXzLDwplVBkJLOe5Lt5b6Yutwv9rJ69oVwyrvU,4397
+flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl,sha256=iAujo8WubDGrz0gg_6zl-TUvkIbNRJM-VJmwKJ9tGY8,3051
+flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl,sha256=UCLEKUpXarkz9tMFtDrxmLv6QuKe5zCimTuoopQedUM,1717
+flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl,sha256=yMVcbfGkTPV9AV16bVdi5fTX1a6jmtszTUrvLXSosio,3305
+flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl,sha256=BhiqRg9w1MGuU5h2_vrLhRc0oHItYzE69qX_JI411k8,2754
 flwr/cli/new/templates/app/code/server.huggingface.py.tpl,sha256=etpjLvGu6pVXzYQBKZp4tTbD3zm461qFo24NliKo74U,591
 flwr/cli/new/templates/app/code/server.jax.py.tpl,sha256=pIdUH-LgWRAGWQYLlivMNf8XnDSNDe2cCuRjlxbRzys,529
 flwr/cli/new/templates/app/code/server.mlx.py.tpl,sha256=RqiZ0k468SOlm9dcPr-fvA8xcWv4zwDCbJfBwL7P9Us,529
@@ -41,7 +39,7 @@ flwr/cli/new/templates/app/code/task.jax.py.tpl,sha256=F05eg149c9icRyVNdfcLyZvAX
 flwr/cli/new/templates/app/code/task.mlx.py.tpl,sha256=jWtCULLRr_9bCIJvoTLMx037-SDl_LF8udtA1UGoXDk,2946
 flwr/cli/new/templates/app/code/task.pytorch.py.tpl,sha256=NgbPix74X1t3ybaGjqdls30vF1i5oY3L7EQExhWhN74,3812
 flwr/cli/new/templates/app/code/task.tensorflow.py.tpl,sha256=SKXAZdgBnPpbAbJ90Rb7oQ5ilnopBx_j_JNFoUDeEAI,1732
-flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl,sha256=
+flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl,sha256=pogRZLrwSfN_XH4NxDdMkhMh1O_7DP90VOoP-cP0HvI,1827
 flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl,sha256=nD0rRUyr_Cj0TaSH8PsiaMhCwu_BuOVX4oqWfFSvOcE,765
 flwr/cli/new/templates/app/pyproject.jax.toml.tpl,sha256=Tq6jeGcoOKzMwWWYxMVnzMcipLURHLiW69iYlD1ywMg,659
 flwr/cli/new/templates/app/pyproject.mlx.toml.tpl,sha256=SHwYAA2qgIlOAU3Sb9BKSZcZ7O9biACg27MHexXUtDw,741
@@ -50,7 +48,7 @@ flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl,sha256=vIO1ArukTC76ogYLNmJ
 flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl,sha256=jk_5teoyOVM9QdBea8J-nk10S6TKw81QZiiKB54ATF0,654
 flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl,sha256=bRIvPCPvTTI4Eo5b61Rmw8WdDw3sjcohciTXgULN5l8,702
 flwr/cli/run/__init__.py,sha256=oCd6HmQDx-sqver1gecgx-uMA38BLTSiiKpl7RGNceg,789
-flwr/cli/run/run.py,sha256=
+flwr/cli/run/run.py,sha256=RI6MgLBNYxmacjQg8XMAQ7VKxbV0DkRyJTfe4GsDFuw,7979
 flwr/cli/utils.py,sha256=l65Ul0YsSBPuypk0uorAtEDmLEYiUrzpCXi6zCg9mJ4,4506
 flwr/client/__init__.py,sha256=wzJZsYJIHf_8-PMzvfbinyzzjgh1UP1vLrAw2_yEbKI,1345
 flwr/client/app.py,sha256=o_2bhmlBeZATtWnAPZhL-Q1Ly0QZxc9ou4i7t0HKumE,31956
@@ -286,8 +284,8 @@ flwr/superexec/exec_grpc.py,sha256=PhqGoZEpTMxSQmUSV8Wgtzb1Za_pHJ-adZqo5RYnDyE,1
 flwr/superexec/exec_servicer.py,sha256=jl0aKVjm0PLQABcTL5c3jdSIzb0Z6hpVOtrAn4Ob7ts,2323
 flwr/superexec/executor.py,sha256=k_adivto6R2U82DADOHNvdtobehBYreRek1gOEBIQnQ,2318
 flwr/superexec/simulation.py,sha256=J6pw-RqCSiUed8I_3MasZH4tl57ZmDebPAHNnbb0-vE,7420
-flwr_nightly-1.11.0.
-flwr_nightly-1.11.0.
-flwr_nightly-1.11.0.
-flwr_nightly-1.11.0.
-flwr_nightly-1.11.0.
+flwr_nightly-1.11.0.dev20240827.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+flwr_nightly-1.11.0.dev20240827.dist-info/METADATA,sha256=EzNACXTNbj22gGXrQbuOE8JeQfD7NxYOHb-FGCLOBTQ,15703
+flwr_nightly-1.11.0.dev20240827.dist-info/WHEEL,sha256=FMvqSimYX_P7y0a7UY-_Mc83r5zkBZsCYPm7Lr0Bsq4,88
+flwr_nightly-1.11.0.dev20240827.dist-info/entry_points.txt,sha256=3cDQVJEBRCSLzJrVYAgjXpoCjuQ74I3A9NZ61DOHdVo,388
+flwr_nightly-1.11.0.dev20240827.dist-info/RECORD,,
flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl
DELETED
@@ -1,89 +0,0 @@
-"""$project_name: A Flower / FlowerTune app."""
-
-import os
-import warnings
-from datetime import datetime
-
-from flwr_datasets import FederatedDataset
-from hydra import compose, initialize
-from hydra.utils import instantiate
-
-from flwr.client import ClientApp
-from flwr.common import Context, ndarrays_to_parameters
-from flwr.server import ServerApp, ServerAppComponents, ServerConfig
-
-from $import_name.client_app import gen_client_fn, get_parameters
-from $import_name.dataset import get_tokenizer_and_data_collator_and_propt_formatting
-from $import_name.models import get_model
-from $import_name.server_app import fit_weighted_average, get_evaluate_fn, get_on_fit_config
-
-# Avoid warnings
-warnings.filterwarnings("ignore", category=UserWarning)
-os.environ["TOKENIZERS_PARALLELISM"] = "true"
-os.environ["RAY_DISABLE_DOCKER_CPU_WARNING"] = "1"
-
-# Initialise regular config
-with initialize(config_path="conf", version_base="1.1"):
-    cfg = compose(config_name="config")
-
-# Initialise static config
-with initialize(config_path="conf", version_base="1.1"):
-    cfg_static = compose(config_name="static_config")
-
-cfg.train.num_rounds = cfg_static.num_rounds
-
-# Create output directory given current timestamp
-current_time = datetime.now()
-folder_name = current_time.strftime("%Y-%m-%d_%H-%M-%S")
-save_path = os.path.join(os.getcwd(), f"results/{folder_name}")
-os.makedirs(save_path, exist_ok=True)
-
-# Partition dataset and get dataloaders
-partitioner = instantiate(cfg_static.partitioner)
-fds = FederatedDataset(
-    dataset=cfg_static.dataset.name, partitioners={"train": partitioner}
-)
-(
-    tokenizer,
-    data_collator,
-    formatting_prompts_func,
-) = get_tokenizer_and_data_collator_and_propt_formatting(cfg.model.name)
-
-# ClientApp for Flower Next
-client = ClientApp(
-    client_fn=gen_client_fn(
-        fds,
-        tokenizer,
-        formatting_prompts_func,
-        data_collator,
-        cfg.model,
-        cfg.train,
-        save_path,
-    ),
-)
-
-# Get initial model weights
-init_model = get_model(cfg.model)
-init_model_parameters = get_parameters(init_model)
-init_model_parameters = ndarrays_to_parameters(init_model_parameters)
-
-def server_fn(context: Context):
-    # Instantiate strategy according to config. Here we pass other arguments
-    # that are only defined at runtime.
-    strategy = instantiate(
-        cfg.strategy,
-        on_fit_config_fn=get_on_fit_config(),
-        fit_metrics_aggregation_fn=fit_weighted_average,
-        initial_parameters=init_model_parameters,
-        evaluate_fn=get_evaluate_fn(
-            cfg.model, cfg.train.save_every_round, cfg_static.num_rounds, save_path
-        ),
-    )
-
-    config = ServerConfig(num_rounds=cfg_static.num_rounds)
-
-    return ServerAppComponents(strategy=strategy, config=config)
-
-
-# ServerApp for Flower Next
-server = ServerApp(server_fn=server_fn)
flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl
DELETED
@@ -1,34 +0,0 @@
-# Federated Instruction Tuning
----
-model:
-  name: "mistralai/Mistral-7B-v0.3"
-  quantization: 4 # 8 or 4 if you want to do quantization with BitsAndBytes
-  gradient_checkpointing: True
-  lora:
-    peft_lora_r: 32
-    peft_lora_alpha: 64
-
-train:
-  num_rounds: null
-  save_every_round: 5
-  learning_rate_max: 5e-5
-  learning_rate_min: 1e-6
-  seq_length: 512
-  training_arguments:
-    output_dir: null # to be set by hydra
-    learning_rate: null # to be set by the client
-    per_device_train_batch_size: 16
-    gradient_accumulation_steps: 1
-    logging_steps: 10
-    num_train_epochs: 3
-    max_steps: 10
-    report_to: null
-    save_steps: 1000
-    save_total_limit: 10
-    gradient_checkpointing: True
-    lr_scheduler_type: "constant"
-
-strategy:
-  _target_: flwr.server.strategy.FedAvg
-  fraction_fit: $fraction_fit
-  fraction_evaluate: 0.0 # no client evaluation
flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl
DELETED
@@ -1,48 +0,0 @@
-"""$project_name: A Flower / FlowerTune app."""
-
-from $import_name.client_app import set_parameters
-from $import_name.models import get_model
-
-
-# Get function that will be executed by the strategy's evaluate() method
-# Here we use it to save global model checkpoints
-def get_evaluate_fn(model_cfg, save_every_round, total_round, save_path):
-    """Return an evaluation function for saving global model."""
-
-    def evaluate(server_round: int, parameters, config):
-        # Save model
-        if server_round != 0 and (
-            server_round == total_round or server_round % save_every_round == 0
-        ):
-            # Init model
-            model = get_model(model_cfg)
-            set_parameters(model, parameters)
-
-            model.save_pretrained(f"{save_path}/peft_{server_round}")
-
-        return 0.0, {}
-
-    return evaluate
-
-
-def get_on_fit_config():
-    """
-    Return a function that will be used to construct the config
-    that the client's fit() method will receive.
-    """
-
-    def fit_config_fn(server_round: int):
-        fit_config = {"current_round": server_round}
-        return fit_config
-
-    return fit_config_fn
-
-
-def fit_weighted_average(metrics):
-    """Aggregate (federated) evaluation metrics."""
-    # Multiply accuracy of each client by number of examples used
-    losses = [num_examples * m["train_loss"] for num_examples, m in metrics]
-    examples = [num_examples for num_examples, _ in metrics]
-
-    # Aggregate and return custom metric (weighted average)
-    return {"train_loss": sum(losses) / sum(examples)}
flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl
DELETED
@@ -1,11 +0,0 @@
-# Federated Instruction Tuning (static)
----
-dataset:
-  name: $dataset_name
-
-# FL experimental settings
-num_clients: $num_clients # total number of clients
-num_rounds: 200
-partitioner:
-  _target_: flwr_datasets.partitioner.IidPartitioner
-  num_partitions: $num_clients
{flwr_nightly-1.11.0.dev20240826.dist-info → flwr_nightly-1.11.0.dev20240827.dist-info}/LICENSE
RENAMED
File without changes

{flwr_nightly-1.11.0.dev20240826.dist-info → flwr_nightly-1.11.0.dev20240827.dist-info}/WHEEL
RENAMED
File without changes

{flwr_nightly-1.11.0.dev20240826.dist-info → flwr_nightly-1.11.0.dev20240827.dist-info}/entry_points.txt
RENAMED
File without changes