flwr-nightly 1.11.0.dev20240823__py3-none-any.whl → 1.12.0.dev20240906__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/app.py +0 -2
- flwr/cli/new/new.py +24 -10
- flwr/cli/new/templates/app/LICENSE.tpl +202 -0
- flwr/cli/new/templates/app/README.baseline.md.tpl +127 -0
- flwr/cli/new/templates/app/README.flowertune.md.tpl +16 -6
- flwr/cli/new/templates/app/code/__init__.baseline.py.tpl +1 -0
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +58 -0
- flwr/cli/new/templates/app/code/dataset.baseline.py.tpl +36 -0
- flwr/cli/new/templates/app/code/flwr_tune/{client.py.tpl → client_app.py.tpl} +50 -40
- flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +32 -2
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +0 -3
- flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +95 -0
- flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +83 -0
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +80 -0
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +46 -0
- flwr/cli/new/templates/app/code/strategy.baseline.py.tpl +1 -0
- flwr/cli/new/templates/app/code/utils.baseline.py.tpl +1 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +138 -0
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +34 -7
- flwr/cli/run/run.py +2 -2
- flwr/client/__init__.py +0 -4
- flwr/client/grpc_rere_client/client_interceptor.py +13 -4
- flwr/client/supernode/app.py +3 -1
- flwr/common/config.py +14 -11
- flwr/common/telemetry.py +36 -30
- flwr/server/__init__.py +0 -4
- flwr/server/app.py +13 -13
- flwr/server/compat/app.py +0 -5
- flwr/server/driver/grpc_driver.py +1 -3
- flwr/server/run_serverapp.py +15 -1
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +19 -8
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +11 -11
- flwr/server/superlink/state/in_memory_state.py +15 -15
- flwr/server/superlink/state/sqlite_state.py +10 -10
- flwr/server/superlink/state/state.py +8 -8
- flwr/simulation/run_simulation.py +23 -6
- flwr/superexec/__init__.py +0 -6
- flwr/superexec/app.py +3 -1
- {flwr_nightly-1.11.0.dev20240823.dist-info → flwr_nightly-1.12.0.dev20240906.dist-info}/METADATA +3 -3
- {flwr_nightly-1.11.0.dev20240823.dist-info → flwr_nightly-1.12.0.dev20240906.dist-info}/RECORD +43 -35
- flwr_nightly-1.12.0.dev20240906.dist-info/entry_points.txt +10 -0
- flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +0 -89
- flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +0 -34
- flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +0 -48
- flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +0 -11
- flwr_nightly-1.11.0.dev20240823.dist-info/entry_points.txt +0 -10
- {flwr_nightly-1.11.0.dev20240823.dist-info → flwr_nightly-1.12.0.dev20240906.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.11.0.dev20240823.dist-info → flwr_nightly-1.12.0.dev20240906.dist-info}/WHEEL +0 -0
|
@@ -1,20 +1,32 @@
|
|
|
1
1
|
"""$project_name: A Flower / FlowerTune app."""
|
|
2
2
|
|
|
3
|
+
import os
|
|
4
|
+
import warnings
|
|
3
5
|
from collections import OrderedDict
|
|
4
|
-
from typing import
|
|
6
|
+
from typing import Dict, Tuple
|
|
5
7
|
|
|
6
8
|
import torch
|
|
9
|
+
from flwr.client import ClientApp, NumPyClient
|
|
10
|
+
from flwr.common import Context
|
|
11
|
+
from flwr.common.config import unflatten_dict
|
|
12
|
+
from flwr.common.typing import NDArrays, Scalar
|
|
7
13
|
from omegaconf import DictConfig
|
|
8
14
|
from peft import get_peft_model_state_dict, set_peft_model_state_dict
|
|
9
15
|
from transformers import TrainingArguments
|
|
10
16
|
from trl import SFTTrainer
|
|
11
17
|
|
|
12
|
-
from
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
18
|
+
from $import_name.dataset import (
|
|
19
|
+
get_tokenizer_and_data_collator_and_propt_formatting,
|
|
20
|
+
load_data,
|
|
21
|
+
replace_keys,
|
|
22
|
+
)
|
|
16
23
|
from $import_name.models import cosine_annealing, get_model
|
|
17
24
|
|
|
25
|
+
# Avoid warnings: pre-set environment flags (presumably read by the HF
# tokenizers library and by Ray) and silence every UserWarning.
os.environ.update(
    {"TOKENIZERS_PARALLELISM": "true", "RAY_DISABLE_DOCKER_CPU_WARNING": "1"}
)
warnings.filterwarnings("ignore", category=UserWarning)
|
|
29
|
+
|
|
18
30
|
|
|
19
31
|
# pylint: disable=too-many-arguments
|
|
20
32
|
# pylint: disable=too-many-instance-attributes
|
|
@@ -29,7 +41,7 @@ class FlowerClient(NumPyClient):
|
|
|
29
41
|
tokenizer,
|
|
30
42
|
formatting_prompts_func,
|
|
31
43
|
data_collator,
|
|
32
|
-
|
|
44
|
+
num_rounds,
|
|
33
45
|
): # pylint: disable=too-many-arguments
|
|
34
46
|
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
35
47
|
self.train_cfg = train_cfg
|
|
@@ -37,13 +49,12 @@ class FlowerClient(NumPyClient):
|
|
|
37
49
|
self.tokenizer = tokenizer
|
|
38
50
|
self.formatting_prompts_func = formatting_prompts_func
|
|
39
51
|
self.data_collator = data_collator
|
|
40
|
-
self.
|
|
52
|
+
self.num_rounds = num_rounds
|
|
53
|
+
self.trainset = trainset
|
|
41
54
|
|
|
42
55
|
# instantiate model
|
|
43
56
|
self.model = get_model(model_cfg)
|
|
44
57
|
|
|
45
|
-
self.trainset = trainset
|
|
46
|
-
|
|
47
58
|
def fit(
|
|
48
59
|
self, parameters: NDArrays, config: Dict[str, Scalar]
|
|
49
60
|
) -> Tuple[NDArrays, int, Dict]:
|
|
@@ -52,13 +63,13 @@ class FlowerClient(NumPyClient):
|
|
|
52
63
|
|
|
53
64
|
new_lr = cosine_annealing(
|
|
54
65
|
int(config["current_round"]),
|
|
55
|
-
self.
|
|
66
|
+
self.num_rounds,
|
|
56
67
|
self.train_cfg.learning_rate_max,
|
|
57
68
|
self.train_cfg.learning_rate_min,
|
|
58
69
|
)
|
|
59
70
|
|
|
60
71
|
self.training_argumnets.learning_rate = new_lr
|
|
61
|
-
self.training_argumnets.output_dir =
|
|
72
|
+
self.training_argumnets.output_dir = config["save_path"]
|
|
62
73
|
|
|
63
74
|
# Construct trainer
|
|
64
75
|
trainer = SFTTrainer(
|
|
@@ -95,32 +106,31 @@ def get_parameters(model) -> NDArrays:
|
|
|
95
106
|
return [val.cpu().numpy() for _, val in state_dict.items()]
|
|
96
107
|
|
|
97
108
|
|
|
98
|
-
def
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
return client_fn
|
|
109
|
+
def client_fn(context: Context) -> FlowerClient:
    """Create the Flower client for the node described by ``context``.

    The node's ``partition-id`` and ``num-partitions`` select its data shard. The
    run config supplies the model, training and dataset settings. It is first
    passed through ``unflatten_dict`` and then ``replace_keys``, which swaps ``-``
    for ``_`` in keys, and is wrapped in a ``DictConfig``.

    NOTE(review): annotated as returning ``FlowerClient`` but actually returns
    the result of ``FlowerClient(...).to_client()``.
    """
    node_cfg = context.node_config
    partition_id = node_cfg["partition-id"]
    num_partitions = node_cfg["num-partitions"]
    num_rounds = context.run_config["num-server-rounds"]
    cfg = DictConfig(replace_keys(unflatten_dict(context.run_config)))

    # Local training shard for this node
    trainset = load_data(partition_id, num_partitions, cfg.static.dataset.name)
    tokenizer, collator, prompt_formatter = (
        get_tokenizer_and_data_collator_and_propt_formatting(cfg.model.name)
    )

    numpy_client = FlowerClient(
        cfg.model,
        cfg.train,
        trainset,
        tokenizer,
        prompt_formatter,
        collator,
        num_rounds,
    )
    return numpy_client.to_client()


# Flower ClientApp
app = ClientApp(client_fn)
|
|
@@ -1,8 +1,12 @@
|
|
|
1
1
|
"""$project_name: A Flower / FlowerTune app."""
|
|
2
2
|
|
|
3
|
+
from flwr_datasets import FederatedDataset
|
|
4
|
+
from flwr_datasets.partitioner import IidPartitioner
|
|
3
5
|
from transformers import AutoTokenizer
|
|
4
6
|
from trl import DataCollatorForCompletionOnlyLM
|
|
5
7
|
|
|
8
|
+
# Module-level cache for the FederatedDataset: load_data() builds it on its
# first call (while it is still None) and reuses it on every later call.
FDS = None # Cache FederatedDataset
|
|
9
|
+
|
|
6
10
|
|
|
7
11
|
def formatting_prompts_func(example):
|
|
8
12
|
"""Construct prompts."""
|
|
@@ -24,7 +28,6 @@ def formatting_prompts_func(example):
|
|
|
24
28
|
|
|
25
29
|
def get_tokenizer_and_data_collator_and_propt_formatting(model_name: str):
|
|
26
30
|
"""Get tokenizer, data_collator and prompt formatting."""
|
|
27
|
-
# From: https://huggingface.co/docs/trl/en/sft_trainer
|
|
28
31
|
tokenizer = AutoTokenizer.from_pretrained(
|
|
29
32
|
model_name, use_fast=True, padding_side="right"
|
|
30
33
|
)
|
|
@@ -49,9 +52,36 @@ def formatting(dataset):
|
|
|
49
52
|
def reformat(dataset, llm_task):
    """Normalize the dataset's columns for the given ``llm_task``.

    The ``output`` column is always renamed to ``response``.

    - ``finance`` / ``code``: each row is additionally mapped through
      ``formatting`` and the ``input`` column is dropped.
    - ``medical``: the original ``instruction`` column is dropped and ``input``
      is renamed to ``instruction``.

    Any other task only gets the ``output`` -> ``response`` rename.
    """
    reformatted = dataset.rename_column("output", "response")
    if llm_task in ("finance", "code"):
        return reformatted.map(formatting, remove_columns=["input"])
    if llm_task == "medical":
        reformatted = reformatted.remove_columns(["instruction"]).rename_column(
            "input", "instruction"
        )
    return reformatted
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def load_data(partition_id: int, num_partitions: int, dataset_name: str):
    """Return partition ``partition_id`` of the ``train`` split, reformatted.

    The ``FederatedDataset`` (IID-partitioned into ``num_partitions`` shards) is
    created on the first call and cached in the module-level ``FDS``. Later calls
    reuse it, so ``num_partitions`` and ``dataset_name`` are ignored once the
    cache is populated. The partition is passed through ``reformat`` with
    ``llm_task="generalnlp"``.
    """
    global FDS
    if FDS is None:
        FDS = FederatedDataset(
            dataset=dataset_name,
            partitioners={"train": IidPartitioner(num_partitions=num_partitions)},
        )
    partition = FDS.load_partition(partition_id, "train")
    return reformat(partition, llm_task="generalnlp")
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def replace_keys(input_dict, match="-", target="_"):
    """Return a copy of ``input_dict`` with ``match`` replaced by ``target`` in keys.

    Nested dictionaries are rewritten recursively; all other values are kept
    as-is. If two keys collide after replacement, the later one wins.
    """
    return {
        key.replace(match, target): (
            replace_keys(value, match, target) if isinstance(value, dict) else value
        )
        for key, value in input_dict.items()
    }
|
|
@@ -22,9 +22,6 @@ def cosine_annealing(
|
|
|
22
22
|
|
|
23
23
|
def get_model(model_cfg: DictConfig):
|
|
24
24
|
"""Load model with appropriate quantization config and other optimizations.
|
|
25
|
-
|
|
26
|
-
Please refer to this example for `peft + BitsAndBytes`:
|
|
27
|
-
https://github.com/huggingface/peft/blob/main/examples/fp4_finetuning/finetune_fp4_opt_bnb_peft.py
|
|
28
25
|
"""
|
|
29
26
|
if model_cfg.quantization == 4:
|
|
30
27
|
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
|
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
"""$project_name: A Flower / FlowerTune app."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
|
|
6
|
+
from flwr.common import Context, ndarrays_to_parameters
|
|
7
|
+
from flwr.common.config import unflatten_dict
|
|
8
|
+
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
|
|
9
|
+
from omegaconf import DictConfig
|
|
10
|
+
|
|
11
|
+
from $import_name.client_app import get_parameters, set_parameters
|
|
12
|
+
from $import_name.models import get_model
|
|
13
|
+
from $import_name.dataset import replace_keys
|
|
14
|
+
from $import_name.strategy import FlowerTuneLlm
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# Build the callback run by the strategy's evaluate() method; here it is used
# to save global model checkpoints rather than to evaluate.
def get_evaluate_fn(model_cfg, save_every_round, total_round, save_path):
    """Return a server-side ``evaluate`` callback that checkpoints the global model.

    The callback never evaluates anything; it always returns ``(0.0, {})``. On the
    final round (``total_round``) and on every ``save_every_round``-th round
    (round 0 excluded), it rebuilds the model from ``model_cfg``, loads the
    aggregated ``parameters`` into it and calls ``save_pretrained`` with the path
    ``{save_path}/peft_{server_round}``.
    """

    def evaluate(server_round: int, parameters, config):
        is_checkpoint_round = server_round != 0 and (
            server_round == total_round or server_round % save_every_round == 0
        )
        if is_checkpoint_round:
            model = get_model(model_cfg)
            set_parameters(model, parameters)
            model.save_pretrained(f"{save_path}/peft_{server_round}")
        return 0.0, {}

    return evaluate
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def get_on_fit_config(save_path):
    """Return the ``on_fit_config_fn`` that builds each round's client fit config.

    Every config carries the current round number under ``current_round`` and
    ``save_path`` under ``save_path``; the client uses the latter as its trainer
    output directory.
    """

    def fit_config_fn(server_round: int):
        return {"current_round": server_round, "save_path": save_path}

    return fit_config_fn
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def fit_weighted_average(metrics):
    """Aggregate the training metrics clients return from ``fit``.

    Computes the mean of the clients' ``train_loss`` values, weighted by the
    number of examples each client reported.

    Parameters
    ----------
    metrics : list of (int, dict)
        ``(num_examples, metrics_dict)`` pairs, one per client. Each metrics dict
        must contain ``"train_loss"``.

    Returns
    -------
    dict
        ``{"train_loss": weighted_mean}``. If the clients reported zero examples
        in total, the weighted mean is undefined and an empty dict is returned
        instead of raising ``ZeroDivisionError``.
    """
    total_examples = sum(num_examples for num_examples, _ in metrics)
    if total_examples == 0:
        return {}

    # Weight each client's loss by the number of examples it trained on
    weighted_loss = sum(num_examples * m["train_loss"] for num_examples, m in metrics)
    return {"train_loss": weighted_loss / total_examples}
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def server_fn(context: Context):
    """Assemble the ``ServerAppComponents`` (strategy + config) for this run.

    Creates a timestamped ``results/<%Y-%m-%d_%H-%M-%S>`` directory under the
    current working directory. It then initializes the global model from the run
    config and wires a ``FlowerTuneLlm`` strategy whose evaluate callback saves
    checkpoints into that directory.
    """
    # Timestamped output directory for this run
    run_stamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    save_path = os.path.join(os.getcwd(), f"results/{run_stamp}")
    os.makedirs(save_path, exist_ok=True)

    num_rounds = context.run_config["num-server-rounds"]
    cfg = DictConfig(replace_keys(unflatten_dict(context.run_config)))

    # Initial global parameters come from a freshly built model
    initial_parameters = ndarrays_to_parameters(get_parameters(get_model(cfg.model)))

    strategy = FlowerTuneLlm(
        fraction_fit=cfg.strategy.fraction_fit,
        fraction_evaluate=cfg.strategy.fraction_evaluate,
        on_fit_config_fn=get_on_fit_config(save_path),
        fit_metrics_aggregation_fn=fit_weighted_average,
        initial_parameters=initial_parameters,
        evaluate_fn=get_evaluate_fn(
            cfg.model, cfg.train.save_every_round, num_rounds, save_path
        ),
    )
    return ServerAppComponents(
        strategy=strategy, config=ServerConfig(num_rounds=num_rounds)
    )


# Flower ServerApp
app = ServerApp(server_fn=server_fn)
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
"""$project_name: A Flower / FlowerTune app."""
|
|
2
|
+
|
|
3
|
+
from io import BytesIO
|
|
4
|
+
from logging import INFO, WARN
|
|
5
|
+
from typing import List, Tuple, Union
|
|
6
|
+
|
|
7
|
+
from flwr.common import FitIns, FitRes, Parameters, log, parameters_to_ndarrays
|
|
8
|
+
from flwr.server.client_manager import ClientManager
|
|
9
|
+
from flwr.server.client_proxy import ClientProxy
|
|
10
|
+
from flwr.server.strategy import FedAvg
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class FlowerTuneLlm(FedAvg):
    """FedAvg strategy that additionally reports communication costs.

    Client selection and aggregation are delegated to ``FedAvg``. This subclass
    only measures, through a ``CommunicationTracker``, the parameter payload sent
    to clients in ``configure_fit`` and the payload returned in ``aggregate_fit``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.comm_tracker = CommunicationTracker()

    def configure_fit(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ):
        """Configure the next round of training and record the outgoing payload."""
        instructions = super().configure_fit(server_round, parameters, client_manager)

        outgoing = [fit_ins for _client, fit_ins in instructions]
        self.comm_tracker.track(outgoing)
        return instructions

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ):
        """Record the incoming payload, then aggregate fit results (weighted average)."""
        incoming = [fit_res for _client, fit_res in results]
        self.comm_tracker.track(incoming)

        aggregated_parameters, aggregated_metrics = super().aggregate_fit(
            server_round, results, failures
        )
        return aggregated_parameters, aggregated_metrics
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class CommunicationTracker:
    """Communication costs tracker over FL rounds.

    Keeps a running total of the serialized parameter size carried by the tracked
    ``FitIns`` / ``FitRes`` messages (bytes / 1024**2, reported as MB). It logs
    the total on every call and logs a warning whenever the total exceeds
    200,000 MB.
    """

    def __init__(self):
        # Accumulated traffic across all calls to ``track``, in MB.
        self.curr_comm_cost = 0.0

    @staticmethod
    def _compute_bytes(parameters):
        """Return the total byte size of the serialized tensors in ``parameters``."""
        # ``memoryview(...).nbytes`` reads the buffer size without copying the
        # (possibly very large) tensor, unlike wrapping it in a ``BytesIO``.
        return sum(memoryview(tensor).nbytes for tensor in parameters.tensors)

    def track(self, fit_list: List[Union[FitIns, FitRes]]):
        """Add the parameter payload of ``fit_list`` to the running total and log it.

        Parameters
        ----------
        fit_list : list of FitIns or FitRes
            Messages whose ``parameters`` are measured.
        """
        comm_cost = (
            sum(self._compute_bytes(fit_ele.parameters) for fit_ele in fit_list)
            / 1024**2
        )

        self.curr_comm_cost += comm_cost
        log(
            INFO,
            "Communication budget: used %.2f MB (+%.2f MB this round) / 200,000 MB",
            self.curr_comm_cost,
            comm_cost,
        )

        if self.curr_comm_cost > 2e5:
            log(
                WARN,
                "The accumulated communication cost has exceeded 200,000 MB. "
                "Please consider reducing it if you plan to participate "
                "FlowerTune LLM Leaderboard.",
            )
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
"""$project_name: A Flower Baseline."""
|
|
2
|
+
|
|
3
|
+
from collections import OrderedDict
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
import torch.nn.functional as F
|
|
7
|
+
from torch import nn
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class Net(nn.Module):
    """Simple CNN adapted from the 'PyTorch: A 60 Minute Blitz' tutorial.

    Two conv -> ReLU -> 2x2 max-pool stages are followed by three fully
    connected layers that map the flattened 16x5x5 feature map to 10 outputs.
    The 5x5 size implies 3-channel 32x32 inputs.
    """

    def __init__(self):
        super().__init__()
        # Registration order fixes the state_dict key order used when weights
        # are exchanged as ordered lists of arrays, so keep it stable.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Do forward."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        x = x.view(-1, 16 * 5 * 5)
        for fc in (self.fc1, self.fc2):
            x = F.relu(fc(x))
        return self.fc3(x)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def train(net, trainloader, epochs, device, lr=0.1):
    """Train the model on the training set.

    Parameters
    ----------
    net : torch.nn.Module
        Model trained in place; it is moved to ``device``.
    trainloader : iterable of dict
        Batches with ``"img"`` and ``"label"`` tensors.
    epochs : int
        Number of passes over ``trainloader``.
    device : torch.device or str
        Device used for training.
    lr : float, optional
        SGD learning rate (momentum is fixed at 0.9). Defaults to 0.1, the value
        previously hard-coded.

    Returns
    -------
    float
        Mean cross-entropy loss over every optimizer step taken across all
        epochs, or 0.0 if no step was taken (empty loader or ``epochs == 0``).
    """
    net.to(device)  # move model to GPU if available
    criterion = torch.nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9)
    net.train()

    running_loss = 0.0
    num_steps = 0
    for _ in range(epochs):
        for batch in trainloader:
            images = batch["img"].to(device)
            labels = batch["label"].to(device)
            optimizer.zero_grad()
            loss = criterion(net(images), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_steps += 1

    # Divide by the number of steps actually taken (all epochs), not by one
    # epoch's batch count, so the average does not scale with ``epochs``.
    return running_loss / num_steps if num_steps else 0.0
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def test(net, testloader, device):
    """Validate the model on the test set.

    Parameters
    ----------
    net : torch.nn.Module
        Model to evaluate; it is moved to ``device`` and switched to eval mode.
    testloader : iterable of dict
        Batches with ``"img"`` and ``"label"`` tensors. It must support ``len``
        and expose ``.dataset`` (also sized with ``len``).
    device : torch.device or str
        Device used for evaluation.

    Returns
    -------
    tuple of (float, float)
        Mean per-batch cross-entropy loss, and accuracy over
        ``len(testloader.dataset)`` samples.
    """
    net.to(device)
    # Disable train-time behaviour (e.g. dropout, batch-norm statistic updates)
    # for any model that has it.
    net.eval()
    criterion = torch.nn.CrossEntropyLoss()
    correct, loss = 0, 0.0
    with torch.no_grad():
        for batch in testloader:
            images = batch["img"].to(device)
            labels = batch["label"].to(device)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            correct += (outputs.argmax(dim=1) == labels).sum().item()
    accuracy = correct / len(testloader.dataset)
    loss = loss / len(testloader)
    return loss, accuracy
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def get_weights(net):
    """Return the model's ``state_dict`` tensors as NumPy arrays, in key order."""
    state = net.state_dict()
    return [tensor.cpu().numpy() for tensor in state.values()]
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def set_weights(net, parameters):
    """Load ``parameters`` (ordered like ``net.state_dict()``) into ``net``.

    Arrays are paired with state-dict keys positionally. Loading is strict, so
    keys left without an array make ``load_state_dict`` raise; surplus arrays
    are silently ignored by the positional pairing.
    """
    keys = net.state_dict().keys()
    new_state = OrderedDict(
        (key, torch.tensor(array)) for key, array in zip(keys, parameters)
    )
    net.load_state_dict(new_state, strict=True)
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
"""$project_name: A Flower Baseline."""
|
|
2
|
+
|
|
3
|
+
from typing import List, Tuple
|
|
4
|
+
|
|
5
|
+
from flwr.common import Context, Metrics, ndarrays_to_parameters
|
|
6
|
+
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
|
|
7
|
+
from flwr.server.strategy import FedAvg
|
|
8
|
+
from $import_name.model import Net, get_weights
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
# Define metric aggregation function
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    """Do weighted average of accuracy metric.

    Parameters
    ----------
    metrics : list of (int, Metrics)
        ``(num_examples, metrics_dict)`` pairs, one per client. Each metrics dict
        must contain ``"accuracy"``.

    Returns
    -------
    Metrics
        ``{"accuracy": weighted_mean}``, where each client's accuracy is weighted
        by its number of examples. If the clients reported zero examples in
        total, the mean is undefined and an empty dict is returned instead of
        raising ``ZeroDivisionError``.
    """
    total_examples = sum(num_examples for num_examples, _ in metrics)
    if total_examples == 0:
        return {}

    # Multiply accuracy of each client by number of examples used
    weighted_accuracy = sum(
        num_examples * float(m["accuracy"]) for num_examples, m in metrics
    )
    return {"accuracy": weighted_accuracy / total_examples}
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def server_fn(context: Context):
    """Assemble the ``ServerAppComponents`` for a FedAvg run.

    ``num-server-rounds`` and ``fraction-fit`` are read from the run config. The
    global model starts from the weights of a freshly initialized ``Net``, and
    client evaluation metrics are combined with ``weighted_average``.
    """
    run_config = context.run_config
    num_rounds = run_config["num-server-rounds"]
    fraction_fit = run_config["fraction-fit"]

    # Initial global parameters from a new model instance
    initial_parameters = ndarrays_to_parameters(get_weights(Net()))

    strategy = FedAvg(
        fraction_fit=float(fraction_fit),
        fraction_evaluate=1.0,
        min_available_clients=2,
        initial_parameters=initial_parameters,
        evaluate_metrics_aggregation_fn=weighted_average,
    )
    return ServerAppComponents(
        strategy=strategy, config=ServerConfig(num_rounds=int(num_rounds))
    )


# Create ServerApp
app = ServerApp(server_fn=server_fn)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""$project_name: A Flower Baseline."""
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""$project_name: A Flower Baseline."""
|
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["hatchling"]
|
|
3
|
+
build-backend = "hatchling.build"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "$package_name"
|
|
7
|
+
version = "1.0.0"
|
|
8
|
+
description = ""
|
|
9
|
+
license = "Apache-2.0"
|
|
10
|
+
dependencies = [
|
|
11
|
+
"flwr[simulation]>=1.11.0",
|
|
12
|
+
"flwr-datasets[vision]>=0.3.0",
|
|
13
|
+
"torch==2.2.1",
|
|
14
|
+
"torchvision==0.17.1",
|
|
15
|
+
]
|
|
16
|
+
|
|
17
|
+
[tool.hatch.metadata]
|
|
18
|
+
allow-direct-references = true
|
|
19
|
+
|
|
20
|
+
[project.optional-dependencies]
|
|
21
|
+
dev = [
|
|
22
|
+
"isort==5.13.2",
|
|
23
|
+
"black==24.2.0",
|
|
24
|
+
"docformatter==1.7.5",
|
|
25
|
+
"mypy==1.8.0",
|
|
26
|
+
"pylint==3.2.6",
|
|
27
|
+
"flake8==5.0.4",
|
|
28
|
+
"pytest==6.2.4",
|
|
29
|
+
"pytest-watch==4.2.0",
|
|
30
|
+
"ruff==0.1.9",
|
|
31
|
+
"types-requests==2.31.0.20240125",
|
|
32
|
+
]
|
|
33
|
+
|
|
34
|
+
[tool.isort]
|
|
35
|
+
profile = "black"
|
|
36
|
+
known_first_party = ["flwr"]
|
|
37
|
+
|
|
38
|
+
[tool.black]
|
|
39
|
+
line-length = 88
|
|
40
|
+
target-version = ["py38", "py39", "py310", "py311"]
|
|
41
|
+
|
|
42
|
+
[tool.pytest.ini_options]
|
|
43
|
+
minversion = "6.2"
|
|
44
|
+
addopts = "-qq"
|
|
45
|
+
testpaths = [
|
|
46
|
+
"flwr_baselines",
|
|
47
|
+
]
|
|
48
|
+
|
|
49
|
+
[tool.mypy]
|
|
50
|
+
ignore_missing_imports = true
|
|
51
|
+
strict = false
|
|
52
|
+
plugins = "numpy.typing.mypy_plugin"
|
|
53
|
+
|
|
54
|
+
[tool.pylint."MESSAGES CONTROL"]
|
|
55
|
+
disable = "duplicate-code,too-few-public-methods,useless-import-alias"
|
|
56
|
+
good-names = "i,j,k,_,x,y,X,Y,K,N"
|
|
57
|
+
max-args = 10
|
|
58
|
+
max-attributes = 15
|
|
59
|
+
max-locals = 36
|
|
60
|
+
max-branches = 20
|
|
61
|
+
max-statements = 55
|
|
62
|
+
|
|
63
|
+
[tool.pylint.typecheck]
|
|
64
|
+
generated-members = "numpy.*, torch.*, tensorflow.*"
|
|
65
|
+
|
|
66
|
+
[[tool.mypy.overrides]]
|
|
67
|
+
module = [
|
|
68
|
+
"importlib.metadata.*",
|
|
69
|
+
"importlib_metadata.*",
|
|
70
|
+
]
|
|
71
|
+
follow_imports = "skip"
|
|
72
|
+
follow_imports_for_stubs = true
|
|
73
|
+
disallow_untyped_calls = false
|
|
74
|
+
|
|
75
|
+
[[tool.mypy.overrides]]
|
|
76
|
+
module = "torch.*"
|
|
77
|
+
follow_imports = "skip"
|
|
78
|
+
follow_imports_for_stubs = true
|
|
79
|
+
|
|
80
|
+
[tool.docformatter]
|
|
81
|
+
wrap-summaries = 88
|
|
82
|
+
wrap-descriptions = 88
|
|
83
|
+
|
|
84
|
+
[tool.ruff]
|
|
85
|
+
target-version = "py38"
|
|
86
|
+
line-length = 88
|
|
87
|
+
select = ["D", "E", "F", "W", "B", "ISC", "C4"]
|
|
88
|
+
fixable = ["D", "E", "F", "W", "B", "ISC", "C4"]
|
|
89
|
+
ignore = ["B024", "B027"]
|
|
90
|
+
exclude = [
|
|
91
|
+
".bzr",
|
|
92
|
+
".direnv",
|
|
93
|
+
".eggs",
|
|
94
|
+
".git",
|
|
95
|
+
".hg",
|
|
96
|
+
".mypy_cache",
|
|
97
|
+
".nox",
|
|
98
|
+
".pants.d",
|
|
99
|
+
".pytype",
|
|
100
|
+
".ruff_cache",
|
|
101
|
+
".svn",
|
|
102
|
+
".tox",
|
|
103
|
+
".venv",
|
|
104
|
+
"__pypackages__",
|
|
105
|
+
"_build",
|
|
106
|
+
"buck-out",
|
|
107
|
+
"build",
|
|
108
|
+
"dist",
|
|
109
|
+
"node_modules",
|
|
110
|
+
"venv",
|
|
111
|
+
"proto",
|
|
112
|
+
]
|
|
113
|
+
|
|
114
|
+
[tool.ruff.pydocstyle]
|
|
115
|
+
convention = "numpy"
|
|
116
|
+
|
|
117
|
+
[tool.hatch.build.targets.wheel]
|
|
118
|
+
packages = ["."]
|
|
119
|
+
|
|
120
|
+
[tool.flwr.app]
|
|
121
|
+
publisher = "$username"
|
|
122
|
+
|
|
123
|
+
[tool.flwr.app.components]
|
|
124
|
+
serverapp = "$import_name.server_app:app"
|
|
125
|
+
clientapp = "$import_name.client_app:app"
|
|
126
|
+
|
|
127
|
+
[tool.flwr.app.config]
|
|
128
|
+
num-server-rounds = 3
|
|
129
|
+
fraction-fit = 0.5
|
|
130
|
+
local-epochs = 1
|
|
131
|
+
|
|
132
|
+
[tool.flwr.federations]
|
|
133
|
+
default = "local-simulation"
|
|
134
|
+
|
|
135
|
+
[tool.flwr.federations.local-simulation]
|
|
136
|
+
options.num-supernodes = 10
|
|
137
|
+
options.backend.client-resources.num-cpus = 2
|
|
138
|
+
options.backend.client-resources.num-gpus = 0.0
|