flwr 1.24.0__py3-none-any.whl → 1.25.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/app_cmd/review.py +13 -3
- flwr/cli/federation/show.py +4 -3
- flwr/cli/ls.py +44 -3
- flwr/cli/new/new.py +106 -297
- flwr/cli/run/run.py +12 -17
- flwr/cli/run_utils.py +23 -5
- flwr/cli/stop.py +1 -1
- flwr/cli/supernode/ls.py +10 -5
- flwr/cli/utils.py +0 -137
- flwr/client/grpc_adapter_client/connection.py +2 -2
- flwr/client/grpc_rere_client/connection.py +6 -3
- flwr/client/rest_client/connection.py +6 -4
- flwr/common/serde.py +6 -0
- flwr/common/typing.py +6 -0
- flwr/proto/fleet_pb2.py +10 -10
- flwr/proto/fleet_pb2.pyi +5 -1
- flwr/proto/run_pb2.py +24 -24
- flwr/proto/run_pb2.pyi +10 -1
- flwr/server/app.py +1 -0
- flwr/server/superlink/fleet/message_handler/message_handler.py +41 -2
- flwr/server/superlink/linkstate/in_memory_linkstate.py +34 -0
- flwr/server/superlink/linkstate/linkstate.py +32 -0
- flwr/server/superlink/linkstate/sqlite_linkstate.py +60 -3
- flwr/supercore/constant.py +3 -0
- flwr/supercore/utils.py +190 -0
- flwr/superlink/servicer/control/control_grpc.py +2 -0
- flwr/superlink/servicer/control/control_servicer.py +88 -5
- flwr/supernode/nodestate/in_memory_nodestate.py +62 -1
- flwr/supernode/nodestate/nodestate.py +45 -0
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +7 -1
- flwr/supernode/start_client_internal.py +7 -4
- {flwr-1.24.0.dist-info → flwr-1.25.0.dist-info}/METADATA +2 -4
- {flwr-1.24.0.dist-info → flwr-1.25.0.dist-info}/RECORD +35 -96
- flwr/cli/new/templates/__init__.py +0 -15
- flwr/cli/new/templates/app/.gitignore.tpl +0 -163
- flwr/cli/new/templates/app/LICENSE.tpl +0 -202
- flwr/cli/new/templates/app/README.baseline.md.tpl +0 -127
- flwr/cli/new/templates/app/README.flowertune.md.tpl +0 -68
- flwr/cli/new/templates/app/README.md.tpl +0 -37
- flwr/cli/new/templates/app/__init__.py +0 -15
- flwr/cli/new/templates/app/code/__init__.baseline.py.tpl +0 -1
- flwr/cli/new/templates/app/code/__init__.py +0 -15
- flwr/cli/new/templates/app/code/__init__.py.tpl +0 -1
- flwr/cli/new/templates/app/code/__init__.pytorch_legacy_api.py.tpl +0 -1
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +0 -75
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +0 -93
- flwr/cli/new/templates/app/code/client.jax.py.tpl +0 -71
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +0 -102
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +0 -46
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +0 -80
- flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +0 -55
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +0 -108
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +0 -82
- flwr/cli/new/templates/app/code/client.xgboost.py.tpl +0 -110
- flwr/cli/new/templates/app/code/dataset.baseline.py.tpl +0 -36
- flwr/cli/new/templates/app/code/flwr_tune/__init__.py +0 -15
- flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +0 -92
- flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +0 -87
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +0 -56
- flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +0 -73
- flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +0 -78
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -66
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +0 -43
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +0 -42
- flwr/cli/new/templates/app/code/server.jax.py.tpl +0 -39
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +0 -41
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +0 -38
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +0 -41
- flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +0 -31
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +0 -44
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +0 -38
- flwr/cli/new/templates/app/code/server.xgboost.py.tpl +0 -56
- flwr/cli/new/templates/app/code/strategy.baseline.py.tpl +0 -1
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +0 -98
- flwr/cli/new/templates/app/code/task.jax.py.tpl +0 -57
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +0 -102
- flwr/cli/new/templates/app/code/task.numpy.py.tpl +0 -7
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +0 -99
- flwr/cli/new/templates/app/code/task.pytorch_legacy_api.py.tpl +0 -111
- flwr/cli/new/templates/app/code/task.sklearn.py.tpl +0 -67
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +0 -52
- flwr/cli/new/templates/app/code/task.xgboost.py.tpl +0 -67
- flwr/cli/new/templates/app/code/utils.baseline.py.tpl +0 -1
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +0 -146
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +0 -80
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +0 -65
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +0 -52
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +0 -56
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +0 -49
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +0 -53
- flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +0 -53
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +0 -52
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +0 -53
- flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +0 -61
- {flwr-1.24.0.dist-info → flwr-1.25.0.dist-info}/WHEEL +0 -0
- {flwr-1.24.0.dist-info → flwr-1.25.0.dist-info}/entry_points.txt +0 -0
|
@@ -1,82 +0,0 @@
|
|
|
1
|
-
"""$project_name: A Flower / $framework_str app."""
|
|
2
|
-
|
|
3
|
-
from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
|
|
4
|
-
from flwr.clientapp import ClientApp
|
|
5
|
-
|
|
6
|
-
from $import_name.task import load_data, load_model
|
|
7
|
-
|
|
8
|
-
# Flower ClientApp; the train/evaluate handlers below are registered on it
app = ClientApp()
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
@app.train()
def train(msg: Message, context: Context):
    """Train the model on local data.

    A model from ``load_model()`` is overwritten with the arrays received in
    ``msg.content["arrays"]``, fitted on this node's training split and sent
    back together with training metrics.

    Args:
        msg: Incoming message carrying an ArrayRecord under ``"arrays"``.
        context: Reads ``local-epochs``, ``batch-size`` and the optional
            ``verbose`` from ``run_config``; ``partition-id`` and
            ``num-partitions`` from ``node_config``.

    Returns:
        A reply Message with the updated weights under ``"arrays"`` and a
        MetricRecord under ``"metrics"`` holding ``num-examples`` and, when the
        fit history reports them, ``train_loss`` / ``train_acc``.
    """
    # Load the model and initialize it with the received weights
    model = load_model()
    model.set_weights(msg.content["arrays"].to_numpy_ndarrays())

    # Read from config
    epochs = context.run_config["local-epochs"]
    batch_size = context.run_config["batch-size"]
    verbose = context.run_config.get("verbose")

    # Load the data
    partition_id = context.node_config["partition-id"]
    num_partitions = context.node_config["num-partitions"]
    x_train, y_train, _, _ = load_data(partition_id, num_partitions)

    # Train the model on local data
    history = model.fit(
        x_train,
        y_train,
        epochs=epochs,
        batch_size=batch_size,
        verbose=verbose,
    )

    def _last(key):
        """Return the final recorded value for ``key``, or None if absent/empty."""
        values = history.history.get(key)
        # A missing key and an empty list are both treated as "not reported"
        return values[-1] if values else None

    train_loss = _last("loss")
    train_acc = _last("accuracy")

    # Construct and return reply Message
    model_record = ArrayRecord(model.get_weights())
    metrics = {"num-examples": len(x_train)}
    if train_loss is not None:
        metrics["train_loss"] = train_loss
    if train_acc is not None:
        metrics["train_acc"] = train_acc
    metric_record = MetricRecord(metrics)
    content = RecordDict({"arrays": model_record, "metrics": metric_record})
    return Message(content=content, reply_to=msg)
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
@app.evaluate()
def evaluate(msg: Message, context: Context):
    """Evaluate the received weights on this node's local test split.

    A model built by ``load_model()`` receives the arrays carried in
    ``msg.content["arrays"]`` and is evaluated silently (``verbose=0``).

    Returns:
        A reply Message with a MetricRecord under ``"metrics"`` containing
        ``eval_loss``, ``eval_acc`` and ``num-examples``.
    """
    # Restore the incoming weights into a freshly built model
    net = load_model()
    net.set_weights(msg.content["arrays"].to_numpy_ndarrays())

    # Fetch the test portion of this node's partition
    node_cfg = context.node_config
    _, _, x_test, y_test = load_data(
        node_cfg["partition-id"], node_cfg["num-partitions"]
    )

    # NOTE(review): unpacking assumes the model is compiled with exactly one
    # extra metric (accuracy) so evaluate() yields [loss, accuracy] — confirm
    eval_loss, eval_acc = net.evaluate(x_test, y_test, verbose=0)

    reply = RecordDict(
        {
            "metrics": MetricRecord(
                {
                    "eval_loss": eval_loss,
                    "eval_acc": eval_acc,
                    "num-examples": len(x_test),
                }
            )
        }
    )
    return Message(content=reply, reply_to=msg)
|
|
@@ -1,110 +0,0 @@
|
|
|
1
|
-
"""$project_name: A Flower / $framework_str app."""
|
|
2
|
-
|
|
3
|
-
import warnings
|
|
4
|
-
|
|
5
|
-
import numpy as np
|
|
6
|
-
import xgboost as xgb
|
|
7
|
-
from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
|
|
8
|
-
from flwr.clientapp import ClientApp
|
|
9
|
-
from flwr.common.config import unflatten_dict
|
|
10
|
-
|
|
11
|
-
from $import_name.task import load_data, replace_keys
|
|
12
|
-
|
|
13
|
-
# Silence UserWarning messages process-wide when this module is imported
warnings.filterwarnings("ignore", category=UserWarning)


# Flower ClientApp; the train/evaluate handlers below are registered on it
app = ClientApp()
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
def _local_boost(bst_input, num_local_round, train_dmatrix):
|
|
21
|
-
# Update trees based on local training data.
|
|
22
|
-
for i in range(num_local_round):
|
|
23
|
-
bst_input.update(train_dmatrix, bst_input.num_boosted_rounds())
|
|
24
|
-
|
|
25
|
-
# Bagging: extract the last N=num_local_round trees for sever aggregation
|
|
26
|
-
bst = bst_input[
|
|
27
|
-
bst_input.num_boosted_rounds()
|
|
28
|
-
- num_local_round : bst_input.num_boosted_rounds()
|
|
29
|
-
]
|
|
30
|
-
return bst
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
@app.train()
def train(msg: Message, context: Context) -> Message:
    """Boost locally and reply with the newly grown trees.

    In server round 1 a fresh booster is trained for ``local-epochs`` rounds.
    In later rounds the global model is read from ``msg.content["arrays"]["0"]``
    as serialized bytes. ``local-epochs`` rounds are then added on top of it
    via ``_local_boost``.

    Returns:
        A reply Message carrying the JSON-serialized local model as the only
        array of an ArrayRecord under ``"arrays"``, and ``num-examples`` under
        ``"metrics"``.
    """
    node_cfg = context.node_config
    train_dmatrix, _, num_train, _ = load_data(
        node_cfg["partition-id"], node_cfg["num-partitions"]
    )

    num_local_round = context.run_config["local-epochs"]
    # Unflatten the run config and turn "-" into "_" in its keys so that the
    # "params" section can be handed straight to xgb.train / xgb.Booster
    params = replace_keys(unflatten_dict(context.run_config))["params"]

    if msg.content["config"]["server-round"] != 1:
        # Later rounds: continue boosting from the received global model
        bst = xgb.Booster(params=params)
        serialized = msg.content["arrays"]["0"].numpy().tobytes()
        bst.load_model(bytearray(serialized))
        bst = _local_boost(bst, num_local_round, train_dmatrix)
    else:
        # First round: train from scratch
        bst = xgb.train(
            params,
            train_dmatrix,
            num_boost_round=num_local_round,
        )

    # Ship the JSON bytes as a flat uint8 array stored as the single item of
    # the ArrayRecord, i.e. readable back through index "0"
    model_np = np.frombuffer(bst.save_raw("json"), dtype=np.uint8)
    content = RecordDict(
        {
            "arrays": ArrayRecord([model_np]),
            "metrics": MetricRecord({"num-examples": num_train}),
        }
    )
    return Message(content=content, reply_to=msg)
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
@app.evaluate()
def evaluate(msg: Message, context: Context) -> Message:
    """Score the received global booster on this node's validation split.

    Returns:
        A reply Message whose ``"metrics"`` MetricRecord holds ``auc`` and
        ``num-examples``.
    """
    node_cfg = context.node_config
    _, valid_dmatrix, _, num_val = load_data(
        node_cfg["partition-id"], node_cfg["num-partitions"]
    )

    params = replace_keys(unflatten_dict(context.run_config))["params"]

    # Rebuild the global booster from its serialized bytes
    bst = xgb.Booster(params=params)
    bst.load_model(bytearray(msg.content["arrays"]["0"].numpy().tobytes()))

    # NOTE(review): the parsing assumes eval_set returns a tab-separated string
    # whose second field is "<name>-<metric>:<value>", and that the first
    # configured eval metric is AUC — confirm against params["eval_metric"]
    report = bst.eval_set(
        evals=[(valid_dmatrix, "valid")],
        iteration=bst.num_boosted_rounds() - 1,
    )
    first_metric_field = report.split("\t")[1]
    auc = float(first_metric_field.split(":")[1])

    reply = RecordDict(
        {"metrics": MetricRecord({"auc": auc, "num-examples": num_val})}
    )
    return Message(content=reply, reply_to=msg)
|
|
@@ -1,36 +0,0 @@
|
|
|
1
|
-
"""$project_name: A Flower Baseline."""
|
|
2
|
-
|
|
3
|
-
from flwr_datasets import FederatedDataset
|
|
4
|
-
from flwr_datasets.partitioner import IidPartitioner
|
|
5
|
-
from torch.utils.data import DataLoader
|
|
6
|
-
from torchvision.transforms import Compose, Normalize, ToTensor
|
|
7
|
-
|
|
8
|
-
FDS = None  # Cache FederatedDataset; created on the first load_data() call
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
def load_data(partition_id: int, num_partitions: int):
    """Load partition CIFAR10 data.

    Returns a ``(trainloader, testloader)`` pair for partition ``partition_id``
    of an IID split of ``uoft-cs/cifar10``:

    * the partition is split 80/20 (seed 42) into train/test;
    * images are converted with ``ToTensor`` and normalized with mean and std
      0.5 on each of the three channels;
    * both loaders use batch size 32, and only the train loader shuffles.

    Note: the ``FederatedDataset`` is built on the first call and cached in the
    module-level ``FDS``, so ``num_partitions`` passed to later calls is ignored.
    """
    global FDS  # pylint: disable=global-statement
    if FDS is None:
        FDS = FederatedDataset(
            dataset="uoft-cs/cifar10",
            partitioners={"train": IidPartitioner(num_partitions=num_partitions)},
        )
    splits = FDS.load_partition(partition_id).train_test_split(
        test_size=0.2, seed=42
    )

    to_normalized_tensor = Compose(
        [ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    def apply_transforms(batch):
        """Apply transforms to the partition from FederatedDataset."""
        batch["img"] = list(map(to_normalized_tensor, batch["img"]))
        return batch

    splits = splits.with_transform(apply_transforms)
    trainloader = DataLoader(splits["train"], batch_size=32, shuffle=True)
    testloader = DataLoader(splits["test"], batch_size=32)
    return trainloader, testloader
|
|
@@ -1,15 +0,0 @@
|
|
|
1
|
-
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
15
|
-
"""Flower CLI `new` command app / code / flwr_tune templates."""
|
|
@@ -1,92 +0,0 @@
|
|
|
1
|
-
"""$project_name: A Flower / FlowerTune app."""
|
|
2
|
-
|
|
3
|
-
import os
|
|
4
|
-
import warnings
|
|
5
|
-
|
|
6
|
-
from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
|
|
7
|
-
from flwr.clientapp import ClientApp
|
|
8
|
-
from flwr.common.config import unflatten_dict
|
|
9
|
-
from omegaconf import DictConfig
|
|
10
|
-
from peft import get_peft_model_state_dict, set_peft_model_state_dict
|
|
11
|
-
from transformers import TrainingArguments
|
|
12
|
-
from trl import SFTTrainer
|
|
13
|
-
|
|
14
|
-
from $import_name.dataset import (
|
|
15
|
-
get_tokenizer_and_data_collator_and_propt_formatting,
|
|
16
|
-
load_data,
|
|
17
|
-
replace_keys,
|
|
18
|
-
)
|
|
19
|
-
from $import_name.models import cosine_annealing, get_model
|
|
20
|
-
|
|
21
|
-
# Avoid warnings. These settings are process-wide, so applying them once at
# import time is sufficient.
os.environ["TOKENIZERS_PARALLELISM"] = "true"
os.environ["RAY_DISABLE_DOCKER_CPU_WARNING"] = "1"
warnings.filterwarnings("ignore", category=UserWarning)


# Flower ClientApp; the train handler below is registered on it
app = ClientApp()
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
@app.train()
def train(msg: Message, context: Context):
    """Fine-tune the PEFT adapter on this node's partition for one round.

    The adapter weights received under ``msg.content["arrays"]`` are loaded
    into a freshly built model. The learning rate for this round comes from
    ``cosine_annealing`` over ``num-server-rounds``. An ``SFTTrainer`` then
    trains locally, writing to the ``save_path`` sent by the server.

    Returns:
        A reply Message with the updated adapter state under ``"arrays"`` and
        ``train_loss`` / ``num-examples`` under ``"metrics"``.
    """
    # Parse config
    node_cfg = context.node_config
    partition_id = node_cfg["partition-id"]
    num_partitions = node_cfg["num-partitions"]
    num_rounds = context.run_config["num-server-rounds"]
    cfg = DictConfig(replace_keys(unflatten_dict(context.run_config)))
    training_arguments = TrainingArguments(**cfg.train.training_arguments)

    # This node's share of the dataset, plus tokenizer / collator / formatter
    trainset = load_data(partition_id, num_partitions, cfg.static.dataset.name)
    tokenizer, data_collator, formatting_prompts_func = (
        get_tokenizer_and_data_collator_and_propt_formatting(cfg.model.name)
    )

    # Fresh model with the received adapter weights loaded into it
    model = get_model(cfg.model)
    set_peft_model_state_dict(model, msg.content["arrays"].to_torch_state_dict())

    # Per-round learning rate and the server-provided output directory
    round_config = msg.content["config"]
    training_arguments.learning_rate = cosine_annealing(
        round_config["server-round"],
        num_rounds,
        cfg.train.learning_rate_max,
        cfg.train.learning_rate_min,
    )
    training_arguments.output_dir = round_config["save_path"]

    # Local supervised fine-tuning
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_arguments,
        max_seq_length=cfg.train.seq_length,
        train_dataset=trainset,
        formatting_func=formatting_prompts_func,
        data_collator=data_collator,
    )
    results = trainer.train()

    # Reply with the updated adapter weights and training metrics
    model_record = ArrayRecord(get_peft_model_state_dict(model))
    metric_record = MetricRecord(
        {"train_loss": results.training_loss, "num-examples": len(trainset)}
    )
    content = RecordDict({"arrays": model_record, "metrics": metric_record})
    return Message(content=content, reply_to=msg)
|
|
@@ -1,87 +0,0 @@
|
|
|
1
|
-
"""$project_name: A Flower / FlowerTune app."""
|
|
2
|
-
|
|
3
|
-
from flwr_datasets import FederatedDataset
|
|
4
|
-
from flwr_datasets.partitioner import IidPartitioner
|
|
5
|
-
from transformers import AutoTokenizer
|
|
6
|
-
from trl import DataCollatorForCompletionOnlyLM
|
|
7
|
-
|
|
8
|
-
FDS = None  # Cache FederatedDataset; created on the first load_data() call
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
def formatting_prompts_func(example):
    """Construct prompts.

    Builds one standard Alpaca prompt
    (https://github.com/tatsu-lab/stanford_alpaca#data-release) per row of a
    batch whose ``"instruction"`` and ``"response"`` columns are parallel
    sequences. Rows are paired positionally; if the columns differ in length,
    rows beyond the shorter one are ignored.

    Returns:
        A list of prompt strings, one per (instruction, response) pair.
    """
    mssg = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request."
    )
    pairs = zip(example["instruction"], example["response"])
    return [
        f"{mssg}\n### Instruction:\n{instruction}\n"
        f"### Response: {response}"
        for instruction, response in pairs
    ]
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
def get_tokenizer_and_data_collator_and_propt_formatting(model_name: str):
    """Get tokenizer, data_collator and prompt formatting.

    Returns:
        ``(tokenizer, data_collator, formatting_prompts_func)``. The tokenizer
        pads on the right and uses its EOS token as the pad token. The
        collator is a ``DataCollatorForCompletionOnlyLM`` keyed on the token
        ids of the Alpaca response tag.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, use_fast=True, padding_side="right"
    )
    tokenizer.pad_token = tokenizer.eos_token

    # The tag is encoded together with its leading newline, and the first two
    # ids are then dropped. NOTE(review): presumably those two ids are the
    # context-dependent newline tokens; verify for the chosen tokenizer.
    response_template_with_context = "\n### Response:"  # alpaca response tag
    encoded_tag = tokenizer.encode(
        response_template_with_context, add_special_tokens=False
    )
    collator = DataCollatorForCompletionOnlyLM(encoded_tag[2:], tokenizer=tokenizer)

    return tokenizer, collator, formatting_prompts_func
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
def formatting(dataset):
    """Format dataset.

    Appends ``dataset["input"]`` to ``dataset["instruction"]`` with a single
    space between them. The same mapping is modified in place and returned.
    """
    instruction, extra = dataset["instruction"], dataset["input"]
    dataset["instruction"] = instruction + " " + extra
    return dataset
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
def reformat(dataset, llm_task):
    """Reformat datasets.

    Normalizes columns to ``instruction`` / ``response``:

    * ``output`` is always renamed to ``response``;
    * for ``"finance"`` and ``"code"``, ``input`` is merged into
      ``instruction`` via ``formatting`` and then dropped;
    * for ``"medical"``, the original ``instruction`` column is dropped and
      ``input`` is renamed to take its place.
    """
    dataset = dataset.rename_column("output", "response")
    if llm_task in ("finance", "code"):
        return dataset.map(formatting, remove_columns=["input"])
    if llm_task == "medical":
        without_instruction = dataset.remove_columns(["instruction"])
        return without_instruction.rename_column("input", "instruction")
    return dataset
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
def load_data(partition_id: int, num_partitions: int, dataset_name: str):
    """Load partition data.

    Returns the ``"train"`` partition ``partition_id`` of an IID split of
    ``dataset_name``, with its columns normalized by ``reformat`` for the
    challenge substituted into this template.

    Note: the ``FederatedDataset`` is built on the first call and cached in the
    module-level ``FDS``. ``num_partitions`` and ``dataset_name`` passed to
    later calls are ignored.
    """
    global FDS  # pylint: disable=global-statement
    if FDS is None:
        FDS = FederatedDataset(
            dataset=dataset_name,
            partitioners={"train": IidPartitioner(num_partitions=num_partitions)},
        )
    partition = FDS.load_partition(partition_id, "train")
    return reformat(partition, llm_task="$llm_challenge_str")
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
def replace_keys(input_dict, match="-", target="_"):
    """Recursively replace match string with target string in dictionary keys.

    Returns a new dict. Nested dicts are rebuilt the same way, and all other
    values are carried over as-is (not copied). If two keys collide after
    replacement, the later one wins.
    """
    return {
        key.replace(match, target): (
            replace_keys(value, match, target) if isinstance(value, dict) else value
        )
        for key, value in input_dict.items()
    }
|
|
@@ -1,56 +0,0 @@
|
|
|
1
|
-
"""$project_name: A Flower / FlowerTune app."""
|
|
2
|
-
|
|
3
|
-
import math
|
|
4
|
-
|
|
5
|
-
import torch
|
|
6
|
-
from omegaconf import DictConfig
|
|
7
|
-
from peft import LoraConfig, get_peft_model
|
|
8
|
-
from peft.utils import prepare_model_for_kbit_training
|
|
9
|
-
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
def cosine_annealing(
    current_round: int,
    total_round: int,
    lrate_max: float = 0.001,
    lrate_min: float = 0.0,
) -> float:
    """Implement cosine annealing learning rate schedule.

    The rate starts at ``lrate_max`` for round 0 and follows half a cosine
    period down to ``lrate_min`` at ``current_round == total_round``.
    """
    angle = math.pi * current_round / total_round
    half_span = 0.5 * (lrate_max - lrate_min)
    return lrate_min + half_span * (1 + math.cos(angle))
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
def get_model(model_cfg: DictConfig):
    """Load model with appropriate quantization config and other optimizations.

    Loads ``model_cfg.name`` as a causal LM in bfloat16, quantized to 4 or 8
    bits according to ``model_cfg.quantization``. The model is then prepared
    for k-bit training and wrapped with a LoRA adapter configured from
    ``model_cfg.lora``.

    Raises:
        ValueError: If ``model_cfg.quantization`` is neither 4 nor 8.
    """
    if model_cfg.quantization == 4:
        quantization_config = BitsAndBytesConfig(load_in_4bit=True)
    elif model_cfg.quantization == 8:
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
    else:
        raise ValueError(
            f"Use 4-bit or 8-bit quantization. You passed: {model_cfg.quantization}"
        )

    model = AutoModelForCausalLM.from_pretrained(
        model_cfg.name,
        quantization_config=quantization_config,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
    )

    model = prepare_model_for_kbit_training(
        model, use_gradient_checkpointing=model_cfg.gradient_checkpointing
    )

    peft_config = LoraConfig(
        r=model_cfg.lora.peft_lora_r,
        lora_alpha=model_cfg.lora.peft_lora_alpha,
        lora_dropout=0.075,
        task_type="CAUSAL_LM",
    )

    # The KV cache is disabled whenever gradient checkpointing is on
    # (presumably because the two conflict during training)
    if model_cfg.gradient_checkpointing:
        model.config.use_cache = False

    return get_peft_model(model, peft_config)
|
|
@@ -1,73 +0,0 @@
|
|
|
1
|
-
"""$project_name: A Flower / FlowerTune app."""
|
|
2
|
-
|
|
3
|
-
import os
|
|
4
|
-
from datetime import datetime
|
|
5
|
-
|
|
6
|
-
from flwr.app import ArrayRecord, ConfigRecord, Context, MetricRecord
|
|
7
|
-
from flwr.common.config import unflatten_dict
|
|
8
|
-
from flwr.serverapp import Grid, ServerApp
|
|
9
|
-
from omegaconf import DictConfig
|
|
10
|
-
from peft import get_peft_model_state_dict, set_peft_model_state_dict
|
|
11
|
-
|
|
12
|
-
from $import_name.dataset import replace_keys
|
|
13
|
-
from $import_name.models import get_model
|
|
14
|
-
from $import_name.strategy import FlowerTuneLlm
|
|
15
|
-
|
|
16
|
-
# Create ServerApp; the run logic is registered on it via @app.main() below
app = ServerApp()
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
@app.main()
def main(grid: Grid, context: Context) -> None:
    """Main entry point for the ServerApp.

    The function does three things:

    * creates ``results/<YYYY-mm-dd_HH-MM-SS>`` under the current working
      directory;
    * builds the initial PEFT adapter weights from a fresh model;
    * runs ``FlowerTuneLlm`` for ``num-server-rounds`` rounds.

    The output folder is sent to clients as ``save_path`` in the train config
    and is also where the evaluate callback writes checkpoints.
    """
    # Timestamped output directory for this run
    run_dir_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    save_path = os.path.join(os.getcwd(), f"results/{run_dir_name}")
    os.makedirs(save_path, exist_ok=True)

    run_cfg = context.run_config
    num_rounds = run_cfg["num-server-rounds"]
    cfg = DictConfig(replace_keys(unflatten_dict(run_cfg)))

    # Initial adapter weights come from a freshly built model
    initial_model = get_model(cfg.model)
    initial_arrays = ArrayRecord(get_peft_model_state_dict(initial_model))

    strategy = FlowerTuneLlm(
        fraction_train=cfg.strategy.fraction_train,
        fraction_evaluate=cfg.strategy.fraction_evaluate,
    )
    strategy.start(
        grid=grid,
        initial_arrays=initial_arrays,
        train_config=ConfigRecord({"save_path": save_path}),
        num_rounds=num_rounds,
        evaluate_fn=get_evaluate_fn(
            cfg.model, cfg.train.save_every_round, num_rounds, save_path
        ),
    )
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
# Build the callback executed by the strategy after each round; here it is
# used only to save global model checkpoints
def get_evaluate_fn(model_cfg, save_every_round, total_round, save_path):
    """Return an evaluation function for saving global model.

    The returned ``evaluate(server_round, arrays)`` performs no evaluation.
    Instead, on the last round and on every ``save_every_round``-th round
    (never round 0) it loads ``arrays`` into a fresh PEFT model and saves it
    to ``{save_path}/peft_{server_round}``. It always returns an empty
    MetricRecord.
    """

    def evaluate(server_round: int, arrays: ArrayRecord) -> MetricRecord:
        # Round 0 (initial parameters) is never checkpointed
        if server_round == 0:
            return MetricRecord()
        # The modulo is only evaluated when this is not the final round
        if server_round == total_round or server_round % save_every_round == 0:
            model = get_model(model_cfg)
            set_peft_model_state_dict(model, arrays.to_torch_state_dict())
            model.save_pretrained(f"{save_path}/peft_{server_round}")
        return MetricRecord()

    return evaluate
|
|
@@ -1,78 +0,0 @@
|
|
|
1
|
-
"""$project_name: A Flower / FlowerTune app."""
|
|
2
|
-
|
|
3
|
-
from collections.abc import Iterable
|
|
4
|
-
from logging import INFO, WARN
|
|
5
|
-
from typing import Optional
|
|
6
|
-
|
|
7
|
-
from flwr.app import ArrayRecord, ConfigRecord, Message, MetricRecord
|
|
8
|
-
from flwr.common import log
|
|
9
|
-
from flwr.serverapp import Grid
|
|
10
|
-
from flwr.serverapp.strategy import FedAvg
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
class FlowerTuneLlm(FedAvg):
    """Customised FedAvg strategy implementation.

    This class behaves just like FedAvg but also tracks the communication
    costs associated with `train` over FL rounds. Both the outgoing training
    messages (``configure_train``) and the received replies
    (``aggregate_train``) are counted.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.comm_tracker = CommunicationTracker()

    def configure_train(
        self, server_round: int, arrays: ArrayRecord, config: ConfigRecord, grid: Grid
    ) -> Iterable[Message]:
        """Configure the next round of training.

        The parent's messages are materialized into a list first. Otherwise,
        counting their bytes could exhaust a one-shot iterator before the
        messages are returned for sending.
        """
        messages = list(
            super().configure_train(server_round, arrays, config, grid)
        )

        # Track communication costs
        self.comm_tracker.track(messages)

        return messages

    def aggregate_train(
        self,
        server_round: int,
        replies: Iterable[Message],
    ) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
        """Aggregate ArrayRecords and MetricRecords in the received Messages.

        ``replies`` is materialized into a list so it can be iterated twice:
        once for cost tracking and once by ``FedAvg.aggregate_train``.
        """
        replies = list(replies)

        # Track communication costs
        self.comm_tracker.track(replies)

        return super().aggregate_train(server_round, replies)
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
class CommunicationTracker:
    """Communication costs tracker over FL rounds.

    Accumulates the size of all ArrayRecords carried by the messages passed
    to ``track``, in MB (bytes / 1024**2). After each call it logs the running
    total against a 200,000 MB budget, and logs a warning while the total
    exceeds that budget.
    """

    def __init__(self):
        self.curr_comm_cost = 0.0

    def track(self, messages: Iterable[Message]):
        """Add the ArrayRecord bytes of ``messages`` to the running total.

        Messages without content are skipped. ``messages`` is iterated once,
        so a one-shot iterator is consumed by this call. The "this round"
        figure in the log refers to this call only; a caller may invoke
        ``track`` more than once per round.
        """
        total_bytes = 0
        for msg in messages:
            if not msg.has_content():
                continue
            for record in msg.content.array_records.values():
                total_bytes += record.count_bytes()
        comm_cost = total_bytes / 1024**2

        self.curr_comm_cost += comm_cost
        log(
            INFO,
            "Communication budget: used %.2f MB (+%.2f MB this round) / 200,000 MB",
            self.curr_comm_cost,
            comm_cost,
        )

        if self.curr_comm_cost > 2e5:
            log(
                WARN,
                "The accumulated communication cost has exceeded 200,000 MB. "
                "Please consider reducing it if you plan to participate "
                "FlowerTune LLM Leaderboard.",
            )
|