flwr-nightly 1.11.0.dev20240822__py3-none-any.whl → 1.11.1.dev20240912__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/app.py +0 -2
- flwr/cli/build.py +1 -1
- flwr/cli/new/new.py +41 -40
- flwr/cli/new/templates/app/LICENSE.tpl +202 -0
- flwr/cli/new/templates/app/README.baseline.md.tpl +127 -0
- flwr/cli/new/templates/app/README.flowertune.md.tpl +16 -6
- flwr/cli/new/templates/app/README.md.tpl +7 -30
- flwr/cli/new/templates/app/code/__init__.baseline.py.tpl +1 -0
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +58 -0
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +19 -29
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +0 -3
- flwr/cli/new/templates/app/code/dataset.baseline.py.tpl +36 -0
- flwr/cli/new/templates/app/code/flwr_tune/{client.py.tpl → client_app.py.tpl} +50 -40
- flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +32 -2
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +0 -3
- flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +95 -0
- flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +83 -0
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +80 -0
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +46 -0
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +18 -3
- flwr/cli/new/templates/app/code/strategy.baseline.py.tpl +1 -0
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -13
- flwr/cli/new/templates/app/code/utils.baseline.py.tpl +1 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +138 -0
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +34 -7
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +9 -1
- flwr/cli/run/run.py +12 -2
- flwr/client/__init__.py +0 -4
- flwr/client/app.py +3 -4
- flwr/client/client.py +22 -1
- flwr/client/client_app.py +2 -2
- flwr/client/grpc_rere_client/client_interceptor.py +15 -7
- flwr/client/numpy_client.py +22 -1
- flwr/client/rest_client/connection.py +1 -1
- flwr/client/supernode/app.py +8 -7
- flwr/common/address.py +43 -0
- flwr/common/config.py +14 -11
- flwr/common/constant.py +12 -1
- flwr/common/record/recordset.py +1 -1
- flwr/common/record/typeddict.py +24 -1
- flwr/common/telemetry.py +36 -30
- flwr/server/__init__.py +0 -4
- flwr/server/app.py +27 -22
- flwr/server/compat/app.py +0 -5
- flwr/server/driver/grpc_driver.py +3 -6
- flwr/server/run_serverapp.py +20 -7
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +15 -2
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +5 -0
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +19 -8
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +13 -12
- flwr/server/superlink/fleet/rest_rere/rest_api.py +71 -122
- flwr/server/superlink/fleet/vce/backend/backend.py +1 -2
- flwr/server/superlink/fleet/vce/backend/raybackend.py +33 -15
- flwr/server/superlink/fleet/vce/vce_api.py +2 -6
- flwr/server/superlink/state/in_memory_state.py +15 -15
- flwr/server/superlink/state/sqlite_state.py +10 -10
- flwr/server/superlink/state/state.py +8 -8
- flwr/server/workflow/secure_aggregation/secagg_workflow.py +1 -0
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +1 -0
- flwr/simulation/ray_transport/ray_actor.py +2 -2
- flwr/simulation/run_simulation.py +85 -25
- flwr/superexec/__init__.py +0 -6
- flwr/superexec/app.py +5 -3
- flwr/superexec/deployment.py +2 -2
- flwr/superexec/simulation.py +20 -1
- {flwr_nightly-1.11.0.dev20240822.dist-info → flwr_nightly-1.11.1.dev20240912.dist-info}/METADATA +3 -3
- {flwr_nightly-1.11.0.dev20240822.dist-info → flwr_nightly-1.11.1.dev20240912.dist-info}/RECORD +70 -62
- flwr_nightly-1.11.1.dev20240912.dist-info/entry_points.txt +10 -0
- flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +0 -89
- flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +0 -34
- flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +0 -48
- flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +0 -11
- flwr_nightly-1.11.0.dev20240822.dist-info/entry_points.txt +0 -10
- {flwr_nightly-1.11.0.dev20240822.dist-info → flwr_nightly-1.11.1.dev20240912.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.11.0.dev20240822.dist-info → flwr_nightly-1.11.1.dev20240912.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
"""$project_name: A Flower Baseline."""
|
|
2
|
+
|
|
3
|
+
from typing import List, Tuple
|
|
4
|
+
|
|
5
|
+
from flwr.common import Context, Metrics, ndarrays_to_parameters
|
|
6
|
+
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
|
|
7
|
+
from flwr.server.strategy import FedAvg
|
|
8
|
+
from $import_name.model import Net, get_weights
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
# Define metric aggregation function
|
|
12
|
+
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
|
|
13
|
+
"""Do weighted average of accuracy metric."""
|
|
14
|
+
# Multiply accuracy of each client by number of examples used
|
|
15
|
+
accuracies = [num_examples * float(m["accuracy"]) for num_examples, m in metrics]
|
|
16
|
+
examples = [num_examples for num_examples, _ in metrics]
|
|
17
|
+
|
|
18
|
+
# Aggregate and return custom metric (weighted average)
|
|
19
|
+
return {"accuracy": sum(accuracies) / sum(examples)}
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def server_fn(context: Context):
|
|
23
|
+
"""Construct components that set the ServerApp behaviour."""
|
|
24
|
+
# Read from config
|
|
25
|
+
num_rounds = context.run_config["num-server-rounds"]
|
|
26
|
+
fraction_fit = context.run_config["fraction-fit"]
|
|
27
|
+
|
|
28
|
+
# Initialize model parameters
|
|
29
|
+
ndarrays = get_weights(Net())
|
|
30
|
+
parameters = ndarrays_to_parameters(ndarrays)
|
|
31
|
+
|
|
32
|
+
# Define strategy
|
|
33
|
+
strategy = FedAvg(
|
|
34
|
+
fraction_fit=float(fraction_fit),
|
|
35
|
+
fraction_evaluate=1.0,
|
|
36
|
+
min_available_clients=2,
|
|
37
|
+
initial_parameters=parameters,
|
|
38
|
+
evaluate_metrics_aggregation_fn=weighted_average,
|
|
39
|
+
)
|
|
40
|
+
config = ServerConfig(num_rounds=int(num_rounds))
|
|
41
|
+
|
|
42
|
+
return ServerAppComponents(strategy=strategy, config=config)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
# Create ServerApp
|
|
46
|
+
app = ServerApp(server_fn=server_fn)
|
|
@@ -1,18 +1,33 @@
|
|
|
1
1
|
"""$project_name: A Flower / $framework_str app."""
|
|
2
2
|
|
|
3
|
-
from flwr.common import Context
|
|
4
|
-
from flwr.server.strategy import FedAvg
|
|
3
|
+
from flwr.common import Context, ndarrays_to_parameters
|
|
5
4
|
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
|
|
5
|
+
from flwr.server.strategy import FedAvg
|
|
6
|
+
from transformers import AutoModelForSequenceClassification
|
|
7
|
+
|
|
8
|
+
from $import_name.task import get_weights
|
|
6
9
|
|
|
7
10
|
|
|
8
11
|
def server_fn(context: Context):
|
|
9
12
|
# Read from config
|
|
10
13
|
num_rounds = context.run_config["num-server-rounds"]
|
|
14
|
+
fraction_fit = context.run_config["fraction-fit"]
|
|
15
|
+
|
|
16
|
+
# Initialize global model
|
|
17
|
+
model_name = context.run_config["model-name"]
|
|
18
|
+
num_labels = context.run_config["num-labels"]
|
|
19
|
+
net = AutoModelForSequenceClassification.from_pretrained(
|
|
20
|
+
model_name, num_labels=num_labels
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
weights = get_weights(net)
|
|
24
|
+
initial_parameters = ndarrays_to_parameters(weights)
|
|
11
25
|
|
|
12
26
|
# Define strategy
|
|
13
27
|
strategy = FedAvg(
|
|
14
|
-
fraction_fit=
|
|
28
|
+
fraction_fit=fraction_fit,
|
|
15
29
|
fraction_evaluate=1.0,
|
|
30
|
+
initial_parameters=initial_parameters,
|
|
16
31
|
)
|
|
17
32
|
config = ServerConfig(num_rounds=num_rounds)
|
|
18
33
|
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""$project_name: A Flower Baseline."""
|
|
@@ -4,24 +4,25 @@ import warnings
|
|
|
4
4
|
from collections import OrderedDict
|
|
5
5
|
|
|
6
6
|
import torch
|
|
7
|
+
import transformers
|
|
8
|
+
from datasets.utils.logging import disable_progress_bar
|
|
7
9
|
from evaluate import load as load_metric
|
|
10
|
+
from flwr_datasets import FederatedDataset
|
|
11
|
+
from flwr_datasets.partitioner import IidPartitioner
|
|
8
12
|
from torch.optim import AdamW
|
|
9
13
|
from torch.utils.data import DataLoader
|
|
10
14
|
from transformers import AutoTokenizer, DataCollatorWithPadding
|
|
11
15
|
|
|
12
|
-
from flwr_datasets import FederatedDataset
|
|
13
|
-
from flwr_datasets.partitioner import IidPartitioner
|
|
14
|
-
|
|
15
|
-
|
|
16
16
|
warnings.filterwarnings("ignore", category=UserWarning)
|
|
17
|
-
|
|
18
|
-
|
|
17
|
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
|
18
|
+
disable_progress_bar()
|
|
19
|
+
transformers.logging.set_verbosity_error()
|
|
19
20
|
|
|
20
21
|
|
|
21
22
|
fds = None # Cache FederatedDataset
|
|
22
23
|
|
|
23
24
|
|
|
24
|
-
def load_data(partition_id: int, num_partitions: int):
|
|
25
|
+
def load_data(partition_id: int, num_partitions: int, model_name: str):
|
|
25
26
|
"""Load IMDB data (training and eval)"""
|
|
26
27
|
# Only initialize `FederatedDataset` once
|
|
27
28
|
global fds
|
|
@@ -35,10 +36,12 @@ def load_data(partition_id: int, num_partitions: int):
|
|
|
35
36
|
# Divide data: 80% train, 20% test
|
|
36
37
|
partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
|
|
37
38
|
|
|
38
|
-
tokenizer = AutoTokenizer.from_pretrained(
|
|
39
|
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
39
40
|
|
|
40
41
|
def tokenize_function(examples):
|
|
41
|
-
return tokenizer(
|
|
42
|
+
return tokenizer(
|
|
43
|
+
examples["text"], truncation=True, add_special_tokens=True, max_length=512
|
|
44
|
+
)
|
|
42
45
|
|
|
43
46
|
partition_train_test = partition_train_test.map(tokenize_function, batched=True)
|
|
44
47
|
partition_train_test = partition_train_test.remove_columns("text")
|
|
@@ -59,12 +62,12 @@ def load_data(partition_id: int, num_partitions: int):
|
|
|
59
62
|
return trainloader, testloader
|
|
60
63
|
|
|
61
64
|
|
|
62
|
-
def train(net, trainloader, epochs):
|
|
65
|
+
def train(net, trainloader, epochs, device):
|
|
63
66
|
optimizer = AdamW(net.parameters(), lr=5e-5)
|
|
64
67
|
net.train()
|
|
65
68
|
for _ in range(epochs):
|
|
66
69
|
for batch in trainloader:
|
|
67
|
-
batch = {k: v.to(
|
|
70
|
+
batch = {k: v.to(device) for k, v in batch.items()}
|
|
68
71
|
outputs = net(**batch)
|
|
69
72
|
loss = outputs.loss
|
|
70
73
|
loss.backward()
|
|
@@ -72,12 +75,12 @@ def train(net, trainloader, epochs):
|
|
|
72
75
|
optimizer.zero_grad()
|
|
73
76
|
|
|
74
77
|
|
|
75
|
-
def test(net, testloader):
|
|
78
|
+
def test(net, testloader, device):
|
|
76
79
|
metric = load_metric("accuracy")
|
|
77
80
|
loss = 0
|
|
78
81
|
net.eval()
|
|
79
82
|
for batch in testloader:
|
|
80
|
-
batch = {k: v.to(
|
|
83
|
+
batch = {k: v.to(device) for k, v in batch.items()}
|
|
81
84
|
with torch.no_grad():
|
|
82
85
|
outputs = net(**batch)
|
|
83
86
|
logits = outputs.logits
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""$project_name: A Flower Baseline."""
|
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["hatchling"]
|
|
3
|
+
build-backend = "hatchling.build"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "$package_name"
|
|
7
|
+
version = "1.0.0"
|
|
8
|
+
description = ""
|
|
9
|
+
license = "Apache-2.0"
|
|
10
|
+
dependencies = [
|
|
11
|
+
"flwr[simulation]>=1.11.0",
|
|
12
|
+
"flwr-datasets[vision]>=0.3.0",
|
|
13
|
+
"torch==2.2.1",
|
|
14
|
+
"torchvision==0.17.1",
|
|
15
|
+
]
|
|
16
|
+
|
|
17
|
+
[tool.hatch.metadata]
|
|
18
|
+
allow-direct-references = true
|
|
19
|
+
|
|
20
|
+
[project.optional-dependencies]
|
|
21
|
+
dev = [
|
|
22
|
+
"isort==5.13.2",
|
|
23
|
+
"black==24.2.0",
|
|
24
|
+
"docformatter==1.7.5",
|
|
25
|
+
"mypy==1.8.0",
|
|
26
|
+
"pylint==3.2.6",
|
|
27
|
+
"flake8==5.0.4",
|
|
28
|
+
"pytest==6.2.4",
|
|
29
|
+
"pytest-watch==4.2.0",
|
|
30
|
+
"ruff==0.1.9",
|
|
31
|
+
"types-requests==2.31.0.20240125",
|
|
32
|
+
]
|
|
33
|
+
|
|
34
|
+
[tool.isort]
|
|
35
|
+
profile = "black"
|
|
36
|
+
known_first_party = ["flwr"]
|
|
37
|
+
|
|
38
|
+
[tool.black]
|
|
39
|
+
line-length = 88
|
|
40
|
+
target-version = ["py38", "py39", "py310", "py311"]
|
|
41
|
+
|
|
42
|
+
[tool.pytest.ini_options]
|
|
43
|
+
minversion = "6.2"
|
|
44
|
+
addopts = "-qq"
|
|
45
|
+
testpaths = [
|
|
46
|
+
"flwr_baselines",
|
|
47
|
+
]
|
|
48
|
+
|
|
49
|
+
[tool.mypy]
|
|
50
|
+
ignore_missing_imports = true
|
|
51
|
+
strict = false
|
|
52
|
+
plugins = "numpy.typing.mypy_plugin"
|
|
53
|
+
|
|
54
|
+
[tool.pylint."MESSAGES CONTROL"]
|
|
55
|
+
disable = "duplicate-code,too-few-public-methods,useless-import-alias"
|
|
56
|
+
good-names = "i,j,k,_,x,y,X,Y,K,N"
|
|
57
|
+
max-args = 10
|
|
58
|
+
max-attributes = 15
|
|
59
|
+
max-locals = 36
|
|
60
|
+
max-branches = 20
|
|
61
|
+
max-statements = 55
|
|
62
|
+
|
|
63
|
+
[tool.pylint.typecheck]
|
|
64
|
+
generated-members = "numpy.*, torch.*, tensorflow.*"
|
|
65
|
+
|
|
66
|
+
[[tool.mypy.overrides]]
|
|
67
|
+
module = [
|
|
68
|
+
"importlib.metadata.*",
|
|
69
|
+
"importlib_metadata.*",
|
|
70
|
+
]
|
|
71
|
+
follow_imports = "skip"
|
|
72
|
+
follow_imports_for_stubs = true
|
|
73
|
+
disallow_untyped_calls = false
|
|
74
|
+
|
|
75
|
+
[[tool.mypy.overrides]]
|
|
76
|
+
module = "torch.*"
|
|
77
|
+
follow_imports = "skip"
|
|
78
|
+
follow_imports_for_stubs = true
|
|
79
|
+
|
|
80
|
+
[tool.docformatter]
|
|
81
|
+
wrap-summaries = 88
|
|
82
|
+
wrap-descriptions = 88
|
|
83
|
+
|
|
84
|
+
[tool.ruff]
|
|
85
|
+
target-version = "py38"
|
|
86
|
+
line-length = 88
|
|
87
|
+
select = ["D", "E", "F", "W", "B", "ISC", "C4"]
|
|
88
|
+
fixable = ["D", "E", "F", "W", "B", "ISC", "C4"]
|
|
89
|
+
ignore = ["B024", "B027"]
|
|
90
|
+
exclude = [
|
|
91
|
+
".bzr",
|
|
92
|
+
".direnv",
|
|
93
|
+
".eggs",
|
|
94
|
+
".git",
|
|
95
|
+
".hg",
|
|
96
|
+
".mypy_cache",
|
|
97
|
+
".nox",
|
|
98
|
+
".pants.d",
|
|
99
|
+
".pytype",
|
|
100
|
+
".ruff_cache",
|
|
101
|
+
".svn",
|
|
102
|
+
".tox",
|
|
103
|
+
".venv",
|
|
104
|
+
"__pypackages__",
|
|
105
|
+
"_build",
|
|
106
|
+
"buck-out",
|
|
107
|
+
"build",
|
|
108
|
+
"dist",
|
|
109
|
+
"node_modules",
|
|
110
|
+
"venv",
|
|
111
|
+
"proto",
|
|
112
|
+
]
|
|
113
|
+
|
|
114
|
+
[tool.ruff.pydocstyle]
|
|
115
|
+
convention = "numpy"
|
|
116
|
+
|
|
117
|
+
[tool.hatch.build.targets.wheel]
|
|
118
|
+
packages = ["."]
|
|
119
|
+
|
|
120
|
+
[tool.flwr.app]
|
|
121
|
+
publisher = "$username"
|
|
122
|
+
|
|
123
|
+
[tool.flwr.app.components]
|
|
124
|
+
serverapp = "$import_name.server_app:app"
|
|
125
|
+
clientapp = "$import_name.client_app:app"
|
|
126
|
+
|
|
127
|
+
[tool.flwr.app.config]
|
|
128
|
+
num-server-rounds = 3
|
|
129
|
+
fraction-fit = 0.5
|
|
130
|
+
local-epochs = 1
|
|
131
|
+
|
|
132
|
+
[tool.flwr.federations]
|
|
133
|
+
default = "local-simulation"
|
|
134
|
+
|
|
135
|
+
[tool.flwr.federations.local-simulation]
|
|
136
|
+
options.num-supernodes = 10
|
|
137
|
+
options.backend.client-resources.num-cpus = 2
|
|
138
|
+
options.backend.client-resources.num-gpus = 0.0
|
|
@@ -8,15 +8,15 @@ version = "1.0.0"
|
|
|
8
8
|
description = ""
|
|
9
9
|
license = "Apache-2.0"
|
|
10
10
|
dependencies = [
|
|
11
|
-
"flwr[simulation]>=1.
|
|
12
|
-
"flwr-datasets>=0.
|
|
13
|
-
"hydra-core==1.3.2",
|
|
11
|
+
"flwr[simulation]>=1.10.0",
|
|
12
|
+
"flwr-datasets>=0.3.0",
|
|
14
13
|
"trl==0.8.1",
|
|
15
14
|
"bitsandbytes==0.43.0",
|
|
16
15
|
"scipy==1.13.0",
|
|
17
16
|
"peft==0.6.2",
|
|
18
17
|
"transformers==4.39.3",
|
|
19
18
|
"sentencepiece==0.2.0",
|
|
19
|
+
"omegaconf==2.3.0",
|
|
20
20
|
]
|
|
21
21
|
|
|
22
22
|
[tool.hatch.build.targets.wheel]
|
|
@@ -26,14 +26,41 @@ packages = ["."]
|
|
|
26
26
|
publisher = "$username"
|
|
27
27
|
|
|
28
28
|
[tool.flwr.app.components]
|
|
29
|
-
serverapp = "$import_name.app
|
|
30
|
-
clientapp = "$import_name.app
|
|
29
|
+
serverapp = "$import_name.server_app:app"
|
|
30
|
+
clientapp = "$import_name.client_app:app"
|
|
31
31
|
|
|
32
32
|
[tool.flwr.app.config]
|
|
33
|
-
|
|
33
|
+
model.name = "mistralai/Mistral-7B-v0.3"
|
|
34
|
+
model.quantization = 4
|
|
35
|
+
model.gradient-checkpointing = true
|
|
36
|
+
model.lora.peft-lora-r = 32
|
|
37
|
+
model.lora.peft-lora-alpha = 64
|
|
38
|
+
train.save-every-round = 5
|
|
39
|
+
train.learning-rate-max = 5e-5
|
|
40
|
+
train.learning-rate-min = 1e-6
|
|
41
|
+
train.seq-length = 512
|
|
42
|
+
train.training-arguments.output-dir = ""
|
|
43
|
+
train.training-arguments.learning-rate = ""
|
|
44
|
+
train.training-arguments.per-device-train-batch-size = 16
|
|
45
|
+
train.training-arguments.gradient-accumulation-steps = 1
|
|
46
|
+
train.training-arguments.logging-steps = 10
|
|
47
|
+
train.training-arguments.num-train-epochs = 3
|
|
48
|
+
train.training-arguments.max-steps = 10
|
|
49
|
+
train.training-arguments.save-steps = 1000
|
|
50
|
+
train.training-arguments.save-total-limit = 10
|
|
51
|
+
train.training-arguments.gradient-checkpointing = true
|
|
52
|
+
train.training-arguments.lr-scheduler-type = "constant"
|
|
53
|
+
strategy.fraction-fit = $fraction_fit
|
|
54
|
+
strategy.fraction-evaluate = 0.0
|
|
55
|
+
num-server-rounds = 200
|
|
56
|
+
|
|
57
|
+
[tool.flwr.app.config.static]
|
|
58
|
+
dataset.name = "$dataset_name"
|
|
34
59
|
|
|
35
60
|
[tool.flwr.federations]
|
|
36
61
|
default = "local-simulation"
|
|
37
62
|
|
|
38
63
|
[tool.flwr.federations.local-simulation]
|
|
39
|
-
options.num-supernodes =
|
|
64
|
+
options.num-supernodes = $num_clients
|
|
65
|
+
options.backend.client-resources.num-cpus = 6
|
|
66
|
+
options.backend.client-resources.num-gpus = 1.0
|
|
@@ -8,7 +8,7 @@ version = "1.0.0"
|
|
|
8
8
|
description = ""
|
|
9
9
|
license = "Apache-2.0"
|
|
10
10
|
dependencies = [
|
|
11
|
-
"flwr[simulation]>=1.
|
|
11
|
+
"flwr[simulation]>=1.11.0",
|
|
12
12
|
"flwr-datasets>=0.3.0",
|
|
13
13
|
"torch==2.2.1",
|
|
14
14
|
"transformers>=4.30.0,<5.0",
|
|
@@ -29,10 +29,18 @@ clientapp = "$import_name.client_app:app"
|
|
|
29
29
|
|
|
30
30
|
[tool.flwr.app.config]
|
|
31
31
|
num-server-rounds = 3
|
|
32
|
+
fraction-fit = 0.5
|
|
32
33
|
local-epochs = 1
|
|
34
|
+
model-name = "prajjwal1/bert-tiny" # Set a larger model if you have access to more GPU resources
|
|
35
|
+
num-labels = 2
|
|
33
36
|
|
|
34
37
|
[tool.flwr.federations]
|
|
35
38
|
default = "localhost"
|
|
36
39
|
|
|
37
40
|
[tool.flwr.federations.localhost]
|
|
38
41
|
options.num-supernodes = 10
|
|
42
|
+
|
|
43
|
+
[tool.flwr.federations.localhost-gpu]
|
|
44
|
+
options.num-supernodes = 10
|
|
45
|
+
options.backend.client-resources.num-cpus = 4 # each ClientApp assumes to use 4CPUs
|
|
46
|
+
options.backend.client-resources.num-gpus = 0.25 # at most 4 ClientApps will run in a given GPU
|
flwr/cli/run/run.py
CHANGED
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
"""Flower command line interface `run` command."""
|
|
16
16
|
|
|
17
17
|
import hashlib
|
|
18
|
+
import json
|
|
18
19
|
import subprocess
|
|
19
20
|
import sys
|
|
20
21
|
from logging import DEBUG
|
|
@@ -123,14 +124,14 @@ def run(
|
|
|
123
124
|
|
|
124
125
|
|
|
125
126
|
def _run_with_superexec(
|
|
126
|
-
app:
|
|
127
|
+
app: Path,
|
|
127
128
|
federation_config: Dict[str, Any],
|
|
128
129
|
config_overrides: Optional[List[str]],
|
|
129
130
|
) -> None:
|
|
130
131
|
|
|
131
132
|
insecure_str = federation_config.get("insecure")
|
|
132
133
|
if root_certificates := federation_config.get("root-certificates"):
|
|
133
|
-
root_certificates_bytes =
|
|
134
|
+
root_certificates_bytes = (app / root_certificates).read_bytes()
|
|
134
135
|
if insecure := bool(insecure_str):
|
|
135
136
|
typer.secho(
|
|
136
137
|
"❌ `root_certificates` were provided but the `insecure` parameter"
|
|
@@ -192,6 +193,8 @@ def _run_without_superexec(
|
|
|
192
193
|
) -> None:
|
|
193
194
|
try:
|
|
194
195
|
num_supernodes = federation_config["options"]["num-supernodes"]
|
|
196
|
+
verbose: Optional[bool] = federation_config["options"].get("verbose")
|
|
197
|
+
backend_cfg = federation_config["options"].get("backend", {})
|
|
195
198
|
except KeyError as err:
|
|
196
199
|
typer.secho(
|
|
197
200
|
"❌ The project's `pyproject.toml` needs to declare the number of"
|
|
@@ -212,6 +215,13 @@ def _run_without_superexec(
|
|
|
212
215
|
f"{num_supernodes}",
|
|
213
216
|
]
|
|
214
217
|
|
|
218
|
+
if backend_cfg:
|
|
219
|
+
# Stringify as JSON
|
|
220
|
+
command.extend(["--backend-config", json.dumps(backend_cfg)])
|
|
221
|
+
|
|
222
|
+
if verbose:
|
|
223
|
+
command.extend(["--verbose"])
|
|
224
|
+
|
|
215
225
|
if config_overrides:
|
|
216
226
|
command.extend(["--run-config", f"{' '.join(config_overrides)}"])
|
|
217
227
|
|
flwr/client/__init__.py
CHANGED
|
@@ -20,8 +20,6 @@ from .app import start_numpy_client as start_numpy_client
|
|
|
20
20
|
from .client import Client as Client
|
|
21
21
|
from .client_app import ClientApp as ClientApp
|
|
22
22
|
from .numpy_client import NumPyClient as NumPyClient
|
|
23
|
-
from .supernode import run_client_app as run_client_app
|
|
24
|
-
from .supernode import run_supernode as run_supernode
|
|
25
23
|
from .typing import ClientFn as ClientFn
|
|
26
24
|
from .typing import ClientFnExt as ClientFnExt
|
|
27
25
|
|
|
@@ -32,8 +30,6 @@ __all__ = [
|
|
|
32
30
|
"ClientFnExt",
|
|
33
31
|
"NumPyClient",
|
|
34
32
|
"mod",
|
|
35
|
-
"run_client_app",
|
|
36
|
-
"run_supernode",
|
|
37
33
|
"start_client",
|
|
38
34
|
"start_numpy_client",
|
|
39
35
|
]
|
flwr/client/app.py
CHANGED
|
@@ -35,6 +35,7 @@ from flwr.client.typing import ClientFnExt
|
|
|
35
35
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH, Context, EventType, Message, event
|
|
36
36
|
from flwr.common.address import parse_address
|
|
37
37
|
from flwr.common.constant import (
|
|
38
|
+
CLIENTAPPIO_API_DEFAULT_ADDRESS,
|
|
38
39
|
MISSING_EXTRA_REST,
|
|
39
40
|
RUN_ID_NUM_BYTES,
|
|
40
41
|
TRANSPORT_TYPE_GRPC_ADAPTER,
|
|
@@ -60,8 +61,6 @@ from .message_handler.message_handler import handle_control_message
|
|
|
60
61
|
from .node_state import NodeState
|
|
61
62
|
from .numpy_client import NumPyClient
|
|
62
63
|
|
|
63
|
-
ADDRESS_CLIENTAPPIO_API_GRPC_RERE = "0.0.0.0:9094"
|
|
64
|
-
|
|
65
64
|
ISOLATION_MODE_SUBPROCESS = "subprocess"
|
|
66
65
|
ISOLATION_MODE_PROCESS = "process"
|
|
67
66
|
|
|
@@ -211,7 +210,7 @@ def start_client_internal(
|
|
|
211
210
|
max_wait_time: Optional[float] = None,
|
|
212
211
|
flwr_path: Optional[Path] = None,
|
|
213
212
|
isolation: Optional[str] = None,
|
|
214
|
-
supernode_address: Optional[str] =
|
|
213
|
+
supernode_address: Optional[str] = CLIENTAPPIO_API_DEFAULT_ADDRESS,
|
|
215
214
|
) -> None:
|
|
216
215
|
"""Start a Flower client node which connects to a Flower server.
|
|
217
216
|
|
|
@@ -266,7 +265,7 @@ def start_client_internal(
|
|
|
266
265
|
by the SueprNode and communicates using gRPC at the address
|
|
267
266
|
`supernode_address`. If `process`, the `ClientApp` runs in a separate isolated
|
|
268
267
|
process and communicates using gRPC at the address `supernode_address`.
|
|
269
|
-
supernode_address : Optional[str] (default: `
|
|
268
|
+
supernode_address : Optional[str] (default: `CLIENTAPPIO_API_DEFAULT_ADDRESS`)
|
|
270
269
|
The SuperNode gRPC server address.
|
|
271
270
|
"""
|
|
272
271
|
if insecure is None:
|
flwr/client/client.py
CHANGED
|
@@ -33,12 +33,13 @@ from flwr.common import (
|
|
|
33
33
|
Parameters,
|
|
34
34
|
Status,
|
|
35
35
|
)
|
|
36
|
+
from flwr.common.logger import warn_deprecated_feature_with_example
|
|
36
37
|
|
|
37
38
|
|
|
38
39
|
class Client(ABC):
|
|
39
40
|
"""Abstract base class for Flower clients."""
|
|
40
41
|
|
|
41
|
-
|
|
42
|
+
_context: Context
|
|
42
43
|
|
|
43
44
|
def get_properties(self, ins: GetPropertiesIns) -> GetPropertiesRes:
|
|
44
45
|
"""Return set of client's properties.
|
|
@@ -141,6 +142,26 @@ class Client(ABC):
|
|
|
141
142
|
metrics={},
|
|
142
143
|
)
|
|
143
144
|
|
|
145
|
+
@property
|
|
146
|
+
def context(self) -> Context:
|
|
147
|
+
"""Getter for `Context` client attribute."""
|
|
148
|
+
warn_deprecated_feature_with_example(
|
|
149
|
+
"Accessing the context via the client's attribute is deprecated.",
|
|
150
|
+
example_message="Instead, pass it to the client's "
|
|
151
|
+
"constructor in your `client_fn()` which already "
|
|
152
|
+
"receives a context object.",
|
|
153
|
+
code_example="def client_fn(context: Context) -> Client:\n\n"
|
|
154
|
+
"\t\t# Your existing client_fn\n\n"
|
|
155
|
+
"\t\t# Pass `context` to the constructor\n"
|
|
156
|
+
"\t\treturn FlowerClient(context).to_client()",
|
|
157
|
+
)
|
|
158
|
+
return self._context
|
|
159
|
+
|
|
160
|
+
@context.setter
|
|
161
|
+
def context(self, context: Context) -> None:
|
|
162
|
+
"""Setter for `Context` client attribute."""
|
|
163
|
+
self._context = context
|
|
164
|
+
|
|
144
165
|
def get_context(self) -> Context:
|
|
145
166
|
"""Get the run context from this client."""
|
|
146
167
|
return self.context
|
flwr/client/client_app.py
CHANGED
|
@@ -41,11 +41,11 @@ def _alert_erroneous_client_fn() -> None:
|
|
|
41
41
|
|
|
42
42
|
def _inspect_maybe_adapt_client_fn_signature(client_fn: ClientFnExt) -> ClientFnExt:
|
|
43
43
|
client_fn_args = inspect.signature(client_fn).parameters
|
|
44
|
-
first_arg = list(client_fn_args.keys())[0]
|
|
45
44
|
|
|
46
45
|
if len(client_fn_args) != 1:
|
|
47
46
|
_alert_erroneous_client_fn()
|
|
48
47
|
|
|
48
|
+
first_arg = list(client_fn_args.keys())[0]
|
|
49
49
|
first_arg_type = client_fn_args[first_arg].annotation
|
|
50
50
|
|
|
51
51
|
if first_arg_type is str or first_arg == "cid":
|
|
@@ -263,7 +263,7 @@ def _registration_error(fn_name: str) -> ValueError:
|
|
|
263
263
|
>>> class FlowerClient(NumPyClient):
|
|
264
264
|
>>> # ...
|
|
265
265
|
>>>
|
|
266
|
-
>>> def client_fn(
|
|
266
|
+
>>> def client_fn(context: Context):
|
|
267
267
|
>>> return FlowerClient().to_client()
|
|
268
268
|
>>>
|
|
269
269
|
>>> app = ClientApp(
|
|
@@ -17,11 +17,13 @@
|
|
|
17
17
|
|
|
18
18
|
import base64
|
|
19
19
|
import collections
|
|
20
|
+
from logging import WARNING
|
|
20
21
|
from typing import Any, Callable, Optional, Sequence, Tuple, Union
|
|
21
22
|
|
|
22
23
|
import grpc
|
|
23
24
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
24
25
|
|
|
26
|
+
from flwr.common.logger import log
|
|
25
27
|
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
|
26
28
|
bytes_to_public_key,
|
|
27
29
|
compute_hmac,
|
|
@@ -128,13 +130,12 @@ class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type:
|
|
|
128
130
|
if self.shared_secret is None:
|
|
129
131
|
raise RuntimeError("Failure to compute hmac")
|
|
130
132
|
|
|
133
|
+
message_bytes = request.SerializeToString(deterministic=True)
|
|
131
134
|
metadata.append(
|
|
132
135
|
(
|
|
133
136
|
_AUTH_TOKEN_HEADER,
|
|
134
137
|
base64.urlsafe_b64encode(
|
|
135
|
-
compute_hmac(
|
|
136
|
-
self.shared_secret, request.SerializeToString(True)
|
|
137
|
-
)
|
|
138
|
+
compute_hmac(self.shared_secret, message_bytes)
|
|
138
139
|
),
|
|
139
140
|
)
|
|
140
141
|
)
|
|
@@ -151,8 +152,15 @@ class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type:
|
|
|
151
152
|
server_public_key_bytes = base64.urlsafe_b64decode(
|
|
152
153
|
_get_value_from_tuples(_PUBLIC_KEY_HEADER, response.initial_metadata())
|
|
153
154
|
)
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
self.
|
|
157
|
-
|
|
155
|
+
|
|
156
|
+
if server_public_key_bytes != b"":
|
|
157
|
+
self.server_public_key = bytes_to_public_key(server_public_key_bytes)
|
|
158
|
+
else:
|
|
159
|
+
log(WARNING, "Can't get server public key, SuperLink may be offline")
|
|
160
|
+
|
|
161
|
+
if self.server_public_key is not None:
|
|
162
|
+
self.shared_secret = generate_shared_key(
|
|
163
|
+
self.private_key, self.server_public_key
|
|
164
|
+
)
|
|
165
|
+
|
|
158
166
|
return response
|
flwr/client/numpy_client.py
CHANGED
|
@@ -27,6 +27,7 @@ from flwr.common import (
|
|
|
27
27
|
ndarrays_to_parameters,
|
|
28
28
|
parameters_to_ndarrays,
|
|
29
29
|
)
|
|
30
|
+
from flwr.common.logger import warn_deprecated_feature_with_example
|
|
30
31
|
from flwr.common.typing import (
|
|
31
32
|
Code,
|
|
32
33
|
EvaluateIns,
|
|
@@ -70,7 +71,7 @@ Example
|
|
|
70
71
|
class NumPyClient(ABC):
|
|
71
72
|
"""Abstract base class for Flower clients using NumPy."""
|
|
72
73
|
|
|
73
|
-
|
|
74
|
+
_context: Context
|
|
74
75
|
|
|
75
76
|
def get_properties(self, config: Config) -> Dict[str, Scalar]:
|
|
76
77
|
"""Return a client's set of properties.
|
|
@@ -174,6 +175,26 @@ class NumPyClient(ABC):
|
|
|
174
175
|
_ = (self, parameters, config)
|
|
175
176
|
return 0.0, 0, {}
|
|
176
177
|
|
|
178
|
+
@property
|
|
179
|
+
def context(self) -> Context:
|
|
180
|
+
"""Getter for `Context` client attribute."""
|
|
181
|
+
warn_deprecated_feature_with_example(
|
|
182
|
+
"Accessing the context via the client's attribute is deprecated.",
|
|
183
|
+
example_message="Instead, pass it to the client's "
|
|
184
|
+
"constructor in your `client_fn()` which already "
|
|
185
|
+
"receives a context object.",
|
|
186
|
+
code_example="def client_fn(context: Context) -> Client:\n\n"
|
|
187
|
+
"\t\t# Your existing client_fn\n\n"
|
|
188
|
+
"\t\t# Pass `context` to the constructor\n"
|
|
189
|
+
"\t\treturn FlowerClient(context).to_client()",
|
|
190
|
+
)
|
|
191
|
+
return self._context
|
|
192
|
+
|
|
193
|
+
@context.setter
|
|
194
|
+
def context(self, context: Context) -> None:
|
|
195
|
+
"""Setter for `Context` client attribute."""
|
|
196
|
+
self._context = context
|
|
197
|
+
|
|
177
198
|
def get_context(self) -> Context:
|
|
178
199
|
"""Get the run context from this client."""
|
|
179
200
|
return self.context
|
|
@@ -275,7 +275,7 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915
|
|
|
275
275
|
req = DeleteNodeRequest(node=node)
|
|
276
276
|
|
|
277
277
|
# Send the request
|
|
278
|
-
res = _request(req, DeleteNodeResponse,
|
|
278
|
+
res = _request(req, DeleteNodeResponse, PATH_DELETE_NODE)
|
|
279
279
|
if res is None:
|
|
280
280
|
return
|
|
281
281
|
|