flwr-nightly 1.10.0.dev20240612__py3-none-any.whl → 1.10.0.dev20240624__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release: this version of flwr-nightly might be problematic.
- flwr/cli/app.py +3 -0
- flwr/cli/build.py +6 -8
- flwr/cli/config_utils.py +53 -3
- flwr/cli/install.py +35 -20
- flwr/cli/new/new.py +104 -28
- flwr/cli/new/templates/app/README.flowertune.md.tpl +56 -0
- flwr/cli/new/templates/app/code/flwr_tune/__init__.py +15 -0
- flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +86 -0
- flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +124 -0
- flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +34 -0
- flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +57 -0
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +59 -0
- flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +48 -0
- flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +11 -0
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +42 -0
- flwr/cli/run/run.py +46 -2
- flwr/client/__init__.py +1 -1
- flwr/client/app.py +22 -10
- flwr/client/client_app.py +1 -1
- flwr/client/dpfedavg_numpy_client.py +1 -1
- flwr/client/grpc_adapter_client/__init__.py +15 -0
- flwr/client/grpc_adapter_client/connection.py +94 -0
- flwr/client/grpc_client/connection.py +5 -1
- flwr/client/grpc_rere_client/__init__.py +1 -1
- flwr/client/grpc_rere_client/connection.py +9 -2
- flwr/client/grpc_rere_client/grpc_adapter.py +133 -0
- flwr/client/message_handler/__init__.py +1 -1
- flwr/client/message_handler/message_handler.py +1 -1
- flwr/client/mod/__init__.py +4 -4
- flwr/client/mod/secure_aggregation/__init__.py +1 -1
- flwr/client/mod/utils.py +1 -1
- flwr/client/rest_client/__init__.py +1 -1
- flwr/client/rest_client/connection.py +10 -2
- flwr/client/supernode/app.py +141 -41
- flwr/common/__init__.py +12 -12
- flwr/common/address.py +1 -1
- flwr/common/config.py +73 -0
- flwr/common/constant.py +16 -1
- flwr/common/date.py +1 -1
- flwr/common/dp.py +1 -1
- flwr/common/grpc.py +1 -1
- flwr/common/object_ref.py +39 -5
- flwr/common/record/__init__.py +1 -1
- flwr/common/secure_aggregation/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/shamir.py +1 -1
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -1
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
- flwr/common/secure_aggregation/quantization.py +1 -1
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
- flwr/common/telemetry.py +4 -0
- flwr/common/typing.py +9 -0
- flwr/common/version.py +14 -0
- flwr/proto/exec_pb2.py +34 -0
- flwr/proto/exec_pb2.pyi +55 -0
- flwr/proto/exec_pb2_grpc.py +101 -0
- flwr/proto/exec_pb2_grpc.pyi +41 -0
- flwr/proto/fab_pb2.py +30 -0
- flwr/proto/fab_pb2.pyi +56 -0
- flwr/proto/fab_pb2_grpc.py +4 -0
- flwr/proto/fab_pb2_grpc.pyi +4 -0
- flwr/server/__init__.py +2 -2
- flwr/server/app.py +62 -25
- flwr/server/compat/app.py +1 -1
- flwr/server/compat/app_utils.py +1 -1
- flwr/server/compat/driver_client_proxy.py +1 -1
- flwr/server/driver/driver.py +6 -0
- flwr/server/driver/grpc_driver.py +85 -63
- flwr/server/driver/inmemory_driver.py +28 -26
- flwr/server/run_serverapp.py +65 -20
- flwr/server/strategy/__init__.py +2 -2
- flwr/server/strategy/bulyan.py +1 -1
- flwr/server/strategy/dpfedavg_adaptive.py +1 -1
- flwr/server/strategy/dpfedavg_fixed.py +1 -1
- flwr/server/strategy/fedadagrad.py +1 -1
- flwr/server/strategy/fedadam.py +1 -1
- flwr/server/strategy/fedavg_android.py +1 -1
- flwr/server/strategy/fedavgm.py +1 -1
- flwr/server/strategy/fedmedian.py +1 -1
- flwr/server/strategy/fedopt.py +1 -1
- flwr/server/strategy/fedprox.py +1 -1
- flwr/server/strategy/fedxgb_bagging.py +1 -1
- flwr/server/strategy/fedxgb_cyclic.py +1 -1
- flwr/server/strategy/fedxgb_nn_avg.py +1 -1
- flwr/server/strategy/fedyogi.py +1 -1
- flwr/server/strategy/krum.py +1 -1
- flwr/server/strategy/qfedavg.py +1 -1
- flwr/server/superlink/driver/__init__.py +1 -1
- flwr/server/superlink/driver/driver_grpc.py +1 -1
- flwr/server/superlink/driver/driver_servicer.py +15 -3
- flwr/server/superlink/fleet/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_adapter/__init__.py +15 -0
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +131 -0
- flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +5 -1
- flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +1 -1
- flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
- flwr/server/superlink/fleet/message_handler/message_handler.py +4 -4
- flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/rest_rere/rest_api.py +1 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +44 -25
- flwr/server/superlink/fleet/vce/vce_api.py +3 -1
- flwr/server/superlink/state/__init__.py +1 -1
- flwr/server/superlink/state/in_memory_state.py +9 -6
- flwr/server/superlink/state/sqlite_state.py +7 -4
- flwr/server/superlink/state/state.py +6 -5
- flwr/server/superlink/state/state_factory.py +11 -2
- flwr/server/utils/__init__.py +1 -1
- flwr/server/utils/tensorboard.py +1 -1
- flwr/simulation/__init__.py +5 -2
- flwr/simulation/app.py +1 -1
- flwr/simulation/ray_transport/__init__.py +1 -1
- flwr/simulation/ray_transport/ray_actor.py +0 -6
- flwr/simulation/ray_transport/ray_client_proxy.py +1 -1
- flwr/simulation/run_simulation.py +63 -22
- flwr/superexec/__init__.py +21 -0
- flwr/superexec/app.py +178 -0
- flwr/superexec/exec_grpc.py +51 -0
- flwr/superexec/exec_servicer.py +65 -0
- flwr/superexec/executor.py +54 -0
- {flwr_nightly-1.10.0.dev20240612.dist-info → flwr_nightly-1.10.0.dev20240624.dist-info}/METADATA +2 -1
- {flwr_nightly-1.10.0.dev20240612.dist-info → flwr_nightly-1.10.0.dev20240624.dist-info}/RECORD +130 -101
- {flwr_nightly-1.10.0.dev20240612.dist-info → flwr_nightly-1.10.0.dev20240624.dist-info}/entry_points.txt +1 -0
- {flwr_nightly-1.10.0.dev20240612.dist-info → flwr_nightly-1.10.0.dev20240624.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.10.0.dev20240612.dist-info → flwr_nightly-1.10.0.dev20240624.dist-info}/WHEEL +0 -0
flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl
ADDED

@@ -0,0 +1,124 @@
+"""$project_name: A Flower / FlowerTune app."""
+
+from collections import OrderedDict
+from typing import Callable, Dict, Tuple
+
+import torch
+from omegaconf import DictConfig
+from peft import get_peft_model_state_dict, set_peft_model_state_dict
+from transformers import TrainingArguments
+from trl import SFTTrainer
+
+from flwr.client import NumPyClient
+from flwr.common.typing import NDArrays, Scalar
+from $import_name.dataset import reformat
+from $import_name.models import cosine_annealing, get_model
+
+
+# pylint: disable=too-many-arguments
+# pylint: disable=too-many-instance-attributes
+class FlowerClient(NumPyClient):
+    """Standard Flower client for CNN training."""
+
+    def __init__(
+        self,
+        model_cfg: DictConfig,
+        train_cfg: DictConfig,
+        trainset,
+        tokenizer,
+        formatting_prompts_func,
+        data_collator,
+        save_path,
+    ):  # pylint: disable=too-many-arguments
+        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        self.train_cfg = train_cfg
+        self.training_argumnets = TrainingArguments(**train_cfg.training_arguments)
+        self.tokenizer = tokenizer
+        self.formatting_prompts_func = formatting_prompts_func
+        self.data_collator = data_collator
+        self.save_path = save_path
+
+        # instantiate model
+        self.model = get_model(model_cfg)
+
+        self.trainset = trainset
+
+    def fit(
+        self, parameters: NDArrays, config: Dict[str, Scalar]
+    ) -> Tuple[NDArrays, int, Dict]:
+        """Implement distributed fit function for a given client."""
+        set_parameters(self.model, parameters)
+
+        new_lr = cosine_annealing(
+            int(config["current_round"]),
+            self.train_cfg.num_rounds,
+            self.train_cfg.learning_rate_max,
+            self.train_cfg.learning_rate_min,
+        )
+
+        self.training_argumnets.learning_rate = new_lr
+        self.training_argumnets.output_dir = self.save_path
+
+        # Construct trainer
+        trainer = SFTTrainer(
+            model=self.model,
+            tokenizer=self.tokenizer,
+            args=self.training_argumnets,
+            max_seq_length=self.train_cfg.seq_length,
+            train_dataset=self.trainset,
+            formatting_func=self.formatting_prompts_func,
+            data_collator=self.data_collator,
+        )
+
+        # Do local training
+        results = trainer.train()
+
+        return (
+            get_parameters(self.model),
+            len(self.trainset),
+            {"train_loss": results.training_loss},
+        )
+
+
+def set_parameters(model, parameters: NDArrays) -> None:
+    """Change the parameters of the model using the given ones."""
+    peft_state_dict_keys = get_peft_model_state_dict(model).keys()
+    params_dict = zip(peft_state_dict_keys, parameters)
+    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
+    set_peft_model_state_dict(model, state_dict)
+
+
+def get_parameters(model) -> NDArrays:
+    """Return the parameters of the current net."""
+    state_dict = get_peft_model_state_dict(model)
+    return [val.cpu().numpy() for _, val in state_dict.items()]
+
+
+def gen_client_fn(
+    fds,
+    tokenizer,
+    formatting_prompts_func,
+    data_collator,
+    model_cfg: DictConfig,
+    train_cfg: DictConfig,
+    save_path: str,
+) -> Callable[[str], FlowerClient]:  # pylint: disable=too-many-arguments
+    """Generate the client function that creates the Flower Clients."""
+
+    def client_fn(cid: str) -> FlowerClient:
+        """Create a Flower client representing a single organization."""
+        # Let's get the partition corresponding to the i-th client
+        client_trainset = fds.load_partition(int(cid), "train")
+        client_trainset = reformat(client_trainset, llm_task="$llm_challenge_str")
+
+        return FlowerClient(
+            model_cfg,
+            train_cfg,
+            client_trainset,
+            tokenizer,
+            formatting_prompts_func,
+            data_collator,
+            save_path,
+        ).to_client()
+
+    return client_fn
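Note that only the LoRA adapter weights travel between client and server here: `get_parameters`/`set_parameters` (de)serialize `get_peft_model_state_dict(model)` rather than the full base model. For orientation, a hedged sketch of how `gen_client_fn` is typically wired into a `ClientApp`; the surrounding `app.py.tpl` is not part of this excerpt, so the variable names below are assumptions:

    from flwr.client import ClientApp

    # Assumed objects, normally built in app.py.tpl from the YAML configs below:
    # fds (a flwr-datasets FederatedDataset), tokenizer, formatting_prompts_func,
    # data_collator, and cfg (an OmegaConf config with .model and .train sections)
    client = ClientApp(
        client_fn=gen_client_fn(
            fds,
            tokenizer,
            formatting_prompts_func,
            data_collator,
            cfg.model,
            cfg.train,
            save_path="./results",
        )
    )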
flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl
ADDED

@@ -0,0 +1,34 @@
+# Federated Instruction Tuning
+---
+model:
+  name: "mistralai/Mistral-7B-v0.3"
+  quantization: 4 # 8 or 4 if you want to do quantization with BitsAndBytes
+  gradient_checkpointing: True
+  lora:
+    peft_lora_r: 32
+    peft_lora_alpha: 64
+
+train:
+  num_rounds: null
+  save_every_round: 5
+  learning_rate_max: 5e-5
+  learning_rate_min: 1e-6
+  seq_length: 512
+  training_arguments:
+    output_dir: null # to be set by hydra
+    learning_rate: null # to be set by the client
+    per_device_train_batch_size: 16
+    gradient_accumulation_steps: 1
+    logging_steps: 10
+    num_train_epochs: 3
+    max_steps: 10
+    report_to: null
+    save_steps: 1000
+    save_total_limit: 10
+    gradient_checkpointing: True
+    lr_scheduler_type: "constant"
+
+strategy:
+  _target_: flwr.server.strategy.FedAvg
+  fraction_fit: $fraction_fit
+  fraction_evaluate: 0.0 # no client evaluation
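The `_target_` key under `strategy` follows the Hydra instantiation convention (hydra-core is a declared dependency of this template): the dotted path is imported and called with the sibling keys as keyword arguments. A minimal sketch, assuming the template placeholders have already been substituted:

    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    cfg = OmegaConf.load("config.yaml")
    # Equivalent to: FedAvg(fraction_fit=..., fraction_evaluate=0.0)
    strategy = instantiate(cfg.strategy)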
flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl
ADDED

@@ -0,0 +1,57 @@
+"""$project_name: A Flower / FlowerTune app."""
+
+from transformers import AutoTokenizer
+from trl import DataCollatorForCompletionOnlyLM
+
+
+def formatting_prompts_func(example):
+    """Construct prompts."""
+    output_texts = []
+    # Constructing a standard Alpaca
+    # (https://github.com/tatsu-lab/stanford_alpaca#data-release) prompt
+    mssg = (
+        "Below is an instruction that describes a task. "
+        "Write a response that appropriately completes the request."
+    )
+    for i in range(len(example["instruction"])):
+        text = (
+            f"{mssg}\n### Instruction:\n{example['instruction'][i]}\n"
+            f"### Response: {example['response'][i]}"
+        )
+        output_texts.append(text)
+    return output_texts
+
+
+def get_tokenizer_and_data_collator_and_propt_formatting(model_name: str):
+    """Get tokenizer, data_collator and prompt formatting."""
+    # From: https://huggingface.co/docs/trl/en/sft_trainer
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_name, use_fast=True, padding_side="right"
+    )
+    tokenizer.pad_token = tokenizer.eos_token
+    response_template_with_context = "\n### Response:"  # alpaca response tag
+    response_template_ids = tokenizer.encode(
+        response_template_with_context, add_special_tokens=False
+    )[2:]
+    data_collator = DataCollatorForCompletionOnlyLM(
+        response_template_ids, tokenizer=tokenizer
+    )
+
+    return tokenizer, data_collator, formatting_prompts_func
+
+
+def formatting(dataset):
+    """Format dataset."""
+    dataset["instruction"] = dataset["instruction"] + " " + dataset["input"]
+    return dataset
+
+
+def reformat(dataset, llm_task):
+    """Reformat datasets."""
+    dataset = dataset.rename_column("output", "response")
+    if llm_task == "finance" or llm_task == "code":
+        dataset = dataset.map(formatting, remove_columns=["input"])
+    if llm_task == "medical":
+        dataset = dataset.remove_columns(["instruction"])
+        dataset = dataset.rename_column("input", "instruction")
+    return dataset
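To make the column handling in `reformat` concrete, here is a small worked example using the `datasets` library that flwr-datasets partitions are built on; the toy rows are illustrative only, and `reformat` is assumed importable from the rendered template:

    from datasets import Dataset

    toy = Dataset.from_dict({
        "instruction": ["Summarize:"],
        "input": ["some text"],
        "output": ["a summary"],
    })

    # finance/code: "output" -> "response", "input" folded into "instruction"
    print(reformat(toy, llm_task="finance").column_names)
    # ['instruction', 'response']

    # medical: "output" -> "response", "input" promoted to "instruction"
    print(reformat(toy, llm_task="medical").column_names)
    # ['instruction', 'response']  (instruction now holds the former "input")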
flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl
ADDED

@@ -0,0 +1,59 @@
+"""$project_name: A Flower / FlowerTune app."""
+
+import math
+
+import torch
+from omegaconf import DictConfig
+from peft import LoraConfig, get_peft_model
+from peft.utils import prepare_model_for_kbit_training
+from transformers import AutoModelForCausalLM, BitsAndBytesConfig
+
+
+def cosine_annealing(
+    current_round: int,
+    total_round: int,
+    lrate_max: float = 0.001,
+    lrate_min: float = 0.0,
+) -> float:
+    """Implement cosine annealing learning rate schedule."""
+    cos_inner = math.pi * current_round / total_round
+    return lrate_min + 0.5 * (lrate_max - lrate_min) * (1 + math.cos(cos_inner))
+
+
+def get_model(model_cfg: DictConfig):
+    """Load model with appropriate quantization config and other optimizations.
+
+    Please refer to this example for `peft + BitsAndBytes`:
+    https://github.com/huggingface/peft/blob/main/examples/fp4_finetuning/finetune_fp4_opt_bnb_peft.py
+    """
+    if model_cfg.quantization == 4:
+        quantization_config = BitsAndBytesConfig(load_in_4bit=True)
+    elif model_cfg.quantization == 8:
+        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+    else:
+        raise ValueError(
+            f"Use 4-bit or 8-bit quantization. You passed: {model_cfg.quantization}/"
+        )
+
+    model = AutoModelForCausalLM.from_pretrained(
+        model_cfg.name,
+        quantization_config=quantization_config,
+        torch_dtype=torch.bfloat16,
+        low_cpu_mem_usage=True,
+    )
+
+    model = prepare_model_for_kbit_training(
+        model, use_gradient_checkpointing=model_cfg.gradient_checkpointing
+    )
+
+    peft_config = LoraConfig(
+        r=model_cfg.lora.peft_lora_r,
+        lora_alpha=model_cfg.lora.peft_lora_alpha,
+        lora_dropout=0.075,
+        task_type="CAUSAL_LM",
+    )
+
+    if model_cfg.gradient_checkpointing:
+        model.config.use_cache = False
+
+    return get_peft_model(model, peft_config)
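For intuition, `cosine_annealing` traces half a cosine from `lrate_max` at round 0 down to `lrate_min` at the final round. With the defaults from `config.yaml.tpl` (max 5e-5, min 1e-6) over the 200 rounds set in `static_config.yaml.tpl`:

    for rnd in (0, 50, 100, 150, 200):
        print(rnd, f"{cosine_annealing(rnd, 200, 5e-5, 1e-6):.2e}")
    # 0   5.00e-05  (max)
    # 50  4.28e-05
    # 100 2.55e-05  (midpoint, (max + min) / 2)
    # 150 8.18e-06
    # 200 1.00e-06  (min)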
flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl
ADDED

@@ -0,0 +1,48 @@
+"""$project_name: A Flower / FlowerTune app."""
+
+from $import_name.client import set_parameters
+from $import_name.models import get_model
+
+
+# Get function that will be executed by the strategy's evaluate() method
+# Here we use it to save global model checkpoints
+def get_evaluate_fn(model_cfg, save_every_round, total_round, save_path):
+    """Return an evaluation function for saving global model."""
+
+    def evaluate(server_round: int, parameters, config):
+        # Save model
+        if server_round != 0 and (
+            server_round == total_round or server_round % save_every_round == 0
+        ):
+            # Init model
+            model = get_model(model_cfg)
+            set_parameters(model, parameters)
+
+            model.save_pretrained(f"{save_path}/peft_{server_round}")
+
+        return 0.0, {}
+
+    return evaluate
+
+
+def get_on_fit_config():
+    """
+    Return a function that will be used to construct the config
+    that the client's fit() method will receive.
+    """
+
+    def fit_config_fn(server_round: int):
+        fit_config = {"current_round": server_round}
+        return fit_config
+
+    return fit_config_fn
+
+
+def fit_weighted_average(metrics):
+    """Aggregate (federated) evaluation metrics."""
+    # Multiply accuracy of each client by number of examples used
+    losses = [num_examples * m["train_loss"] for num_examples, m in metrics]
+    examples = [num_examples for num_examples, _ in metrics]
+
+    # Aggregate and return custom metric (weighted average)
+    return {"train_loss": sum(losses) / sum(examples)}
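These helpers map onto standard `FedAvg` hooks. A hedged sketch of the strategy construction; the actual call lives in `app.py.tpl`, which is outside this excerpt, and the literal values stand in for the template placeholders:

    from flwr.server.strategy import FedAvg

    # model_cfg is assumed to be the cfg.model section loaded from config.yaml
    strategy = FedAvg(
        fraction_fit=0.1,                                 # $fraction_fit
        fraction_evaluate=0.0,                            # no client evaluation
        on_fit_config_fn=get_on_fit_config(),             # injects "current_round"
        fit_metrics_aggregation_fn=fit_weighted_average,  # weighted train_loss
        evaluate_fn=get_evaluate_fn(model_cfg, 5, 200, "./results"),  # checkpoints
    )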
flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl
ADDED

@@ -0,0 +1,11 @@
+# Federated Instruction Tuning (static)
+---
+dataset:
+  name: $dataset_name
+
+# FL experimental settings
+num_clients: $num_clients # total number of clients
+num_rounds: 200
+partitioner:
+  _target_: flwr_datasets.partitioner.IidPartitioner
+  num_partitions: $num_clients
flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl
ADDED

@@ -0,0 +1,42 @@
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[project]
+name = "$package_name"
+version = "1.0.0"
+description = ""
+authors = [
+    { name = "The Flower Authors", email = "hello@flower.ai" },
+]
+license = { text = "Apache License (2.0)" }
+dependencies = [
+    "flwr[simulation]>=1.9.0,<2.0",
+    "flwr-datasets>=0.1.0,<1.0.0",
+    "hydra-core==1.3.2",
+    "trl==0.8.1",
+    "bitsandbytes==0.43.0",
+    "scipy==1.13.0",
+    "peft==0.6.2",
+    "transformers==4.39.3",
+    "sentencepiece==0.2.0",
+]
+
+[tool.hatch.build.targets.wheel]
+packages = ["."]
+
+[flower]
+publisher = "$username"
+
+[flower.components]
+serverapp = "$import_name.app:server"
+clientapp = "$import_name.app:client"
+
+[flower.engine]
+name = "simulation"
+
+[flower.engine.simulation.supernode]
+num = $num_clients
+
+[flower.engine.simulation]
+backend_config = { client_resources = { num_cpus = 8, num_gpus = 1.0 } }
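Note: the `backend_config` table at the end of this template pairs with the `flwr/cli/run/run.py` change below, which reads it from the loaded configuration and forwards it to `_run_simulation`.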
flwr/cli/run/run.py
CHANGED

@@ -16,12 +16,18 @@
 
 import sys
 from enum import Enum
+from logging import DEBUG
 from typing import Optional
 
 import typer
 from typing_extensions import Annotated
 
 from flwr.cli import config_utils
+from flwr.common.constant import SUPEREXEC_DEFAULT_ADDRESS
+from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel
+from flwr.common.logger import log
+from flwr.proto.exec_pb2 import StartRunRequest  # pylint: disable=E0611
+from flwr.proto.exec_pb2_grpc import ExecStub
 from flwr.simulation.run_simulation import _run_simulation
 
 
@@ -31,20 +37,35 @@ class Engine(str, Enum):
     SIMULATION = "simulation"
 
 
+# pylint: disable-next=too-many-locals
 def run(
     engine: Annotated[
         Optional[Engine],
-        typer.Option(case_sensitive=False, help="The engine to run FL with."),
+        typer.Option(
+            case_sensitive=False,
+            help="The engine to run FL with (currently only simulation is supported).",
+        ),
     ] = None,
+    use_superexec: Annotated[
+        bool,
+        typer.Option(
+            case_sensitive=False, help="Use this flag to use the new SuperExec API"
+        ),
+    ] = False,
 ) -> None:
     """Run Flower project."""
+    if use_superexec:
+        _start_superexec_run()
+        return
+
     typer.secho("Loading project configuration... ", fg=typer.colors.BLUE)
 
     config, errors, warnings = config_utils.load_and_validate()
 
     if config is None:
         typer.secho(
-            "Project configuration could not be loaded.\npyproject.toml is invalid:\n"
+            "Project configuration could not be loaded.\n"
+            "pyproject.toml is invalid:\n"
             + "\n".join([f"- {line}" for line in errors]),
             fg=typer.colors.RED,
             bold=True,
@@ -69,12 +90,16 @@ def run(
 
     if engine == Engine.SIMULATION:
         num_supernodes = config["flower"]["engine"]["simulation"]["supernode"]["num"]
+        backend_config = config["flower"]["engine"]["simulation"].get(
+            "backend_config", None
+        )
 
         typer.secho("Starting run... ", fg=typer.colors.BLUE)
         _run_simulation(
             server_app_attr=server_app_ref,
             client_app_attr=client_app_ref,
             num_supernodes=num_supernodes,
+            backend_config=backend_config,
        )
    else:
        typer.secho(
@@ -82,3 +107,22 @@ def run(
            fg=typer.colors.RED,
            bold=True,
        )
+
+
+def _start_superexec_run() -> None:
+    def on_channel_state_change(channel_connectivity: str) -> None:
+        """Log channel connectivity."""
+        log(DEBUG, channel_connectivity)
+
+    channel = create_channel(
+        server_address=SUPEREXEC_DEFAULT_ADDRESS,
+        insecure=True,
+        root_certificates=None,
+        max_message_length=GRPC_MAX_MESSAGE_LENGTH,
+        interceptors=None,
+    )
+    channel.subscribe(on_channel_state_change)
+    stub = ExecStub(channel)
+
+    req = StartRunRequest()
+    stub.StartRun(req)
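In effect, `flwr run --use-superexec` now skips loading the local project and the simulation path entirely and instead submits a `StartRunRequest` to a SuperExec listening at `SUPEREXEC_DEFAULT_ADDRESS`, over an insecure channel with the default gRPC message limit.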
flwr/client/__init__.py
CHANGED
flwr/client/app.py
CHANGED

@@ -19,7 +19,7 @@ import sys
 import time
 from dataclasses import dataclass
 from logging import DEBUG, ERROR, INFO, WARN
-from typing import Callable, ContextManager, Optional, Tuple, Type, Union
+from typing import Callable, ContextManager, Dict, Optional, Tuple, Type, Union
 
 from cryptography.hazmat.primitives.asymmetric import ec
 from grpc import RpcError
@@ -31,6 +31,7 @@ from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, Message, event
 from flwr.common.address import parse_address
 from flwr.common.constant import (
     MISSING_EXTRA_REST,
+    TRANSPORT_TYPE_GRPC_ADAPTER,
     TRANSPORT_TYPE_GRPC_BIDI,
     TRANSPORT_TYPE_GRPC_RERE,
     TRANSPORT_TYPE_REST,
@@ -41,6 +42,7 @@ from flwr.common.logger import log, warn_deprecated_feature
 from flwr.common.message import Error
 from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential
 
+from .grpc_adapter_client.connection import grpc_adapter
 from .grpc_client.connection import grpc_connection
 from .grpc_rere_client.connection import grpc_request_response
 from .message_handler.message_handler import handle_control_message
@@ -177,7 +179,7 @@ def start_client(
 def _start_client_internal(
     *,
     server_address: str,
-    load_client_app_fn: Optional[Callable[[], ClientApp]] = None,
+    load_client_app_fn: Optional[Callable[[str, str], ClientApp]] = None,
     client_fn: Optional[ClientFn] = None,
     client: Optional[Client] = None,
     grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
@@ -252,7 +254,7 @@ def _start_client_internal(
 
         client_fn = single_client_factory
 
-    def _load_client_app() -> ClientApp:
+    def _load_client_app(_1: str, _2: str) -> ClientApp:
        return ClientApp(client_fn=client_fn)
 
    load_client_app_fn = _load_client_app
@@ -308,6 +310,8 @@ def _start_client_internal(
    )
 
    node_state = NodeState()
+    # run_id -> (fab_id, fab_version)
+    run_info: Dict[int, Tuple[str, str]] = {}
 
    while not app_state_tracker.interrupt:
        sleep_duration: int = 0
@@ -319,7 +323,6 @@ def _start_client_internal(
            root_certificates,
            authentication_keys,
        ) as conn:
-            # pylint: disable-next=W0612
            receive, send, create_node, delete_node, get_run = conn
 
            # Register node
@@ -356,13 +359,20 @@ def _start_client_internal(
                    send(out_message)
                    break
 
+                # Get run info
+                run_id = message.metadata.run_id
+                if run_id not in run_info:
+                    if get_run is not None:
+                        run_info[run_id] = get_run(run_id)
+                    # If get_run is None, i.e., in grpc-bidi mode
+                    else:
+                        run_info[run_id] = ("", "")
+
                # Register context for this run
-                node_state.register_context(run_id=message.metadata.run_id)
+                node_state.register_context(run_id=run_id)
 
                # Retrieve context for this run
-                context = node_state.retrieve_context(
-                    run_id=message.metadata.run_id
-                )
+                context = node_state.retrieve_context(run_id=run_id)
 
                # Create an error reply message that will never be used to prevent
                # the used-before-assignment linting error
@@ -373,7 +383,7 @@ def _start_client_internal(
                # Handle app loading and task message
                try:
                    # Load ClientApp instance
-                    client_app: ClientApp = load_client_app_fn()
+                    client_app: ClientApp = load_client_app_fn(*run_info[run_id])
 
                    # Execute ClientApp
                    reply_message = client_app(message=message, context=context)
@@ -411,7 +421,7 @@ def _start_client_internal(
                else:
                    # No exception, update node state
                    node_state.update_context(
-                        run_id=message.metadata.run_id,
+                        run_id=run_id,
                        context=context,
                    )
 
@@ -592,6 +602,8 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
        connection, error_type = http_request_response, RequestsConnectionError
    elif transport == TRANSPORT_TYPE_GRPC_RERE:
        connection, error_type = grpc_request_response, RpcError
+    elif transport == TRANSPORT_TYPE_GRPC_ADAPTER:
+        connection, error_type = grpc_adapter, RpcError
    elif transport == TRANSPORT_TYPE_GRPC_BIDI:
        connection, error_type = grpc_connection, RpcError
    else:
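The widened `load_client_app_fn` signature lets the SuperNode select a `ClientApp` per run based on the FAB id and version returned by `get_run` (both arrive as empty strings in grpc-bidi mode, where `get_run` is unavailable). A hedged sketch of a conforming loader; the registry is hypothetical, and only the `(str, str) -> ClientApp` shape comes from this diff:

    from typing import Dict, Tuple

    from flwr.client import ClientApp

    # Hypothetical registry of installed apps, keyed by (fab_id, fab_version)
    INSTALLED_APPS: Dict[Tuple[str, str], ClientApp] = {}

    def load_client_app(fab_id: str, fab_version: str) -> ClientApp:
        """Return the ClientApp registered for this run's FAB id/version."""
        return INSTALLED_APPS[(fab_id, fab_version)]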
flwr/client/client_app.py
CHANGED

flwr/client/grpc_adapter_client/__init__.py
ADDED

@@ -0,0 +1,15 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Client-side part of the GrpcAdapter transport layer."""