flwr-nightly 1.11.0.dev20240823__py3-none-any.whl → 1.11.1.dev20240912__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/app.py +0 -2
- flwr/cli/new/new.py +41 -40
- flwr/cli/new/templates/app/LICENSE.tpl +202 -0
- flwr/cli/new/templates/app/README.baseline.md.tpl +127 -0
- flwr/cli/new/templates/app/README.flowertune.md.tpl +16 -6
- flwr/cli/new/templates/app/code/__init__.baseline.py.tpl +1 -0
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +58 -0
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +19 -29
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +0 -3
- flwr/cli/new/templates/app/code/dataset.baseline.py.tpl +36 -0
- flwr/cli/new/templates/app/code/flwr_tune/{client.py.tpl → client_app.py.tpl} +50 -40
- flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +32 -2
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +0 -3
- flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +95 -0
- flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +83 -0
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +80 -0
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +46 -0
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +18 -3
- flwr/cli/new/templates/app/code/strategy.baseline.py.tpl +1 -0
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -13
- flwr/cli/new/templates/app/code/utils.baseline.py.tpl +1 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +138 -0
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +34 -7
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +9 -1
- flwr/cli/run/run.py +2 -2
- flwr/client/__init__.py +0 -4
- flwr/client/app.py +3 -4
- flwr/client/client_app.py +2 -2
- flwr/client/grpc_rere_client/client_interceptor.py +15 -7
- flwr/client/supernode/app.py +8 -7
- flwr/common/config.py +14 -11
- flwr/common/constant.py +12 -1
- flwr/common/record/recordset.py +1 -1
- flwr/common/record/typeddict.py +24 -1
- flwr/common/telemetry.py +36 -30
- flwr/server/__init__.py +0 -4
- flwr/server/app.py +21 -22
- flwr/server/compat/app.py +0 -5
- flwr/server/driver/grpc_driver.py +3 -6
- flwr/server/run_serverapp.py +20 -7
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +19 -8
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +13 -12
- flwr/server/superlink/fleet/vce/backend/raybackend.py +21 -12
- flwr/server/superlink/state/in_memory_state.py +15 -15
- flwr/server/superlink/state/sqlite_state.py +10 -10
- flwr/server/superlink/state/state.py +8 -8
- flwr/simulation/ray_transport/ray_actor.py +2 -2
- flwr/simulation/run_simulation.py +37 -8
- flwr/superexec/__init__.py +0 -6
- flwr/superexec/app.py +5 -3
- flwr/superexec/deployment.py +2 -2
- {flwr_nightly-1.11.0.dev20240823.dist-info → flwr_nightly-1.11.1.dev20240912.dist-info}/METADATA +3 -3
- {flwr_nightly-1.11.0.dev20240823.dist-info → flwr_nightly-1.11.1.dev20240912.dist-info}/RECORD +56 -48
- flwr_nightly-1.11.1.dev20240912.dist-info/entry_points.txt +10 -0
- flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +0 -89
- flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +0 -34
- flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +0 -48
- flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +0 -11
- flwr_nightly-1.11.0.dev20240823.dist-info/entry_points.txt +0 -10
- {flwr_nightly-1.11.0.dev20240823.dist-info → flwr_nightly-1.11.1.dev20240912.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.11.0.dev20240823.dist-info → flwr_nightly-1.11.1.dev20240912.dist-info}/WHEEL +0 -0
|
@@ -4,24 +4,25 @@ import warnings
|
|
|
4
4
|
from collections import OrderedDict
|
|
5
5
|
|
|
6
6
|
import torch
|
|
7
|
+
import transformers
|
|
8
|
+
from datasets.utils.logging import disable_progress_bar
|
|
7
9
|
from evaluate import load as load_metric
|
|
10
|
+
from flwr_datasets import FederatedDataset
|
|
11
|
+
from flwr_datasets.partitioner import IidPartitioner
|
|
8
12
|
from torch.optim import AdamW
|
|
9
13
|
from torch.utils.data import DataLoader
|
|
10
14
|
from transformers import AutoTokenizer, DataCollatorWithPadding
|
|
11
15
|
|
|
12
|
-
from flwr_datasets import FederatedDataset
|
|
13
|
-
from flwr_datasets.partitioner import IidPartitioner
|
|
14
|
-
|
|
15
|
-
|
|
16
16
|
warnings.filterwarnings("ignore", category=UserWarning)
|
|
17
|
-
|
|
18
|
-
|
|
17
|
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
|
18
|
+
disable_progress_bar()
|
|
19
|
+
transformers.logging.set_verbosity_error()
|
|
19
20
|
|
|
20
21
|
|
|
21
22
|
fds = None # Cache FederatedDataset
|
|
22
23
|
|
|
23
24
|
|
|
24
|
-
def load_data(partition_id: int, num_partitions: int):
|
|
25
|
+
def load_data(partition_id: int, num_partitions: int, model_name: str):
|
|
25
26
|
"""Load IMDB data (training and eval)"""
|
|
26
27
|
# Only initialize `FederatedDataset` once
|
|
27
28
|
global fds
|
|
@@ -35,10 +36,12 @@ def load_data(partition_id: int, num_partitions: int):
|
|
|
35
36
|
# Divide data: 80% train, 20% test
|
|
36
37
|
partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
|
|
37
38
|
|
|
38
|
-
tokenizer = AutoTokenizer.from_pretrained(
|
|
39
|
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
39
40
|
|
|
40
41
|
def tokenize_function(examples):
|
|
41
|
-
return tokenizer(
|
|
42
|
+
return tokenizer(
|
|
43
|
+
examples["text"], truncation=True, add_special_tokens=True, max_length=512
|
|
44
|
+
)
|
|
42
45
|
|
|
43
46
|
partition_train_test = partition_train_test.map(tokenize_function, batched=True)
|
|
44
47
|
partition_train_test = partition_train_test.remove_columns("text")
|
|
@@ -59,12 +62,12 @@ def load_data(partition_id: int, num_partitions: int):
|
|
|
59
62
|
return trainloader, testloader
|
|
60
63
|
|
|
61
64
|
|
|
62
|
-
def train(net, trainloader, epochs):
|
|
65
|
+
def train(net, trainloader, epochs, device):
|
|
63
66
|
optimizer = AdamW(net.parameters(), lr=5e-5)
|
|
64
67
|
net.train()
|
|
65
68
|
for _ in range(epochs):
|
|
66
69
|
for batch in trainloader:
|
|
67
|
-
batch = {k: v.to(
|
|
70
|
+
batch = {k: v.to(device) for k, v in batch.items()}
|
|
68
71
|
outputs = net(**batch)
|
|
69
72
|
loss = outputs.loss
|
|
70
73
|
loss.backward()
|
|
@@ -72,12 +75,12 @@ def train(net, trainloader, epochs):
|
|
|
72
75
|
optimizer.zero_grad()
|
|
73
76
|
|
|
74
77
|
|
|
75
|
-
def test(net, testloader):
|
|
78
|
+
def test(net, testloader, device):
|
|
76
79
|
metric = load_metric("accuracy")
|
|
77
80
|
loss = 0
|
|
78
81
|
net.eval()
|
|
79
82
|
for batch in testloader:
|
|
80
|
-
batch = {k: v.to(
|
|
83
|
+
batch = {k: v.to(device) for k, v in batch.items()}
|
|
81
84
|
with torch.no_grad():
|
|
82
85
|
outputs = net(**batch)
|
|
83
86
|
logits = outputs.logits
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""$project_name: A Flower Baseline."""
|
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["hatchling"]
|
|
3
|
+
build-backend = "hatchling.build"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "$package_name"
|
|
7
|
+
version = "1.0.0"
|
|
8
|
+
description = ""
|
|
9
|
+
license = "Apache-2.0"
|
|
10
|
+
dependencies = [
|
|
11
|
+
"flwr[simulation]>=1.11.0",
|
|
12
|
+
"flwr-datasets[vision]>=0.3.0",
|
|
13
|
+
"torch==2.2.1",
|
|
14
|
+
"torchvision==0.17.1",
|
|
15
|
+
]
|
|
16
|
+
|
|
17
|
+
[tool.hatch.metadata]
|
|
18
|
+
allow-direct-references = true
|
|
19
|
+
|
|
20
|
+
[project.optional-dependencies]
|
|
21
|
+
dev = [
|
|
22
|
+
"isort==5.13.2",
|
|
23
|
+
"black==24.2.0",
|
|
24
|
+
"docformatter==1.7.5",
|
|
25
|
+
"mypy==1.8.0",
|
|
26
|
+
"pylint==3.2.6",
|
|
27
|
+
"flake8==5.0.4",
|
|
28
|
+
"pytest==6.2.4",
|
|
29
|
+
"pytest-watch==4.2.0",
|
|
30
|
+
"ruff==0.1.9",
|
|
31
|
+
"types-requests==2.31.0.20240125",
|
|
32
|
+
]
|
|
33
|
+
|
|
34
|
+
[tool.isort]
|
|
35
|
+
profile = "black"
|
|
36
|
+
known_first_party = ["flwr"]
|
|
37
|
+
|
|
38
|
+
[tool.black]
|
|
39
|
+
line-length = 88
|
|
40
|
+
target-version = ["py38", "py39", "py310", "py311"]
|
|
41
|
+
|
|
42
|
+
[tool.pytest.ini_options]
|
|
43
|
+
minversion = "6.2"
|
|
44
|
+
addopts = "-qq"
|
|
45
|
+
testpaths = [
|
|
46
|
+
"flwr_baselines",
|
|
47
|
+
]
|
|
48
|
+
|
|
49
|
+
[tool.mypy]
|
|
50
|
+
ignore_missing_imports = true
|
|
51
|
+
strict = false
|
|
52
|
+
plugins = "numpy.typing.mypy_plugin"
|
|
53
|
+
|
|
54
|
+
[tool.pylint."MESSAGES CONTROL"]
|
|
55
|
+
disable = "duplicate-code,too-few-public-methods,useless-import-alias"
|
|
56
|
+
good-names = "i,j,k,_,x,y,X,Y,K,N"
|
|
57
|
+
max-args = 10
|
|
58
|
+
max-attributes = 15
|
|
59
|
+
max-locals = 36
|
|
60
|
+
max-branches = 20
|
|
61
|
+
max-statements = 55
|
|
62
|
+
|
|
63
|
+
[tool.pylint.typecheck]
|
|
64
|
+
generated-members = "numpy.*, torch.*, tensorflow.*"
|
|
65
|
+
|
|
66
|
+
[[tool.mypy.overrides]]
|
|
67
|
+
module = [
|
|
68
|
+
"importlib.metadata.*",
|
|
69
|
+
"importlib_metadata.*",
|
|
70
|
+
]
|
|
71
|
+
follow_imports = "skip"
|
|
72
|
+
follow_imports_for_stubs = true
|
|
73
|
+
disallow_untyped_calls = false
|
|
74
|
+
|
|
75
|
+
[[tool.mypy.overrides]]
|
|
76
|
+
module = "torch.*"
|
|
77
|
+
follow_imports = "skip"
|
|
78
|
+
follow_imports_for_stubs = true
|
|
79
|
+
|
|
80
|
+
[tool.docformatter]
|
|
81
|
+
wrap-summaries = 88
|
|
82
|
+
wrap-descriptions = 88
|
|
83
|
+
|
|
84
|
+
[tool.ruff]
|
|
85
|
+
target-version = "py38"
|
|
86
|
+
line-length = 88
|
|
87
|
+
select = ["D", "E", "F", "W", "B", "ISC", "C4"]
|
|
88
|
+
fixable = ["D", "E", "F", "W", "B", "ISC", "C4"]
|
|
89
|
+
ignore = ["B024", "B027"]
|
|
90
|
+
exclude = [
|
|
91
|
+
".bzr",
|
|
92
|
+
".direnv",
|
|
93
|
+
".eggs",
|
|
94
|
+
".git",
|
|
95
|
+
".hg",
|
|
96
|
+
".mypy_cache",
|
|
97
|
+
".nox",
|
|
98
|
+
".pants.d",
|
|
99
|
+
".pytype",
|
|
100
|
+
".ruff_cache",
|
|
101
|
+
".svn",
|
|
102
|
+
".tox",
|
|
103
|
+
".venv",
|
|
104
|
+
"__pypackages__",
|
|
105
|
+
"_build",
|
|
106
|
+
"buck-out",
|
|
107
|
+
"build",
|
|
108
|
+
"dist",
|
|
109
|
+
"node_modules",
|
|
110
|
+
"venv",
|
|
111
|
+
"proto",
|
|
112
|
+
]
|
|
113
|
+
|
|
114
|
+
[tool.ruff.pydocstyle]
|
|
115
|
+
convention = "numpy"
|
|
116
|
+
|
|
117
|
+
[tool.hatch.build.targets.wheel]
|
|
118
|
+
packages = ["."]
|
|
119
|
+
|
|
120
|
+
[tool.flwr.app]
|
|
121
|
+
publisher = "$username"
|
|
122
|
+
|
|
123
|
+
[tool.flwr.app.components]
|
|
124
|
+
serverapp = "$import_name.server_app:app"
|
|
125
|
+
clientapp = "$import_name.client_app:app"
|
|
126
|
+
|
|
127
|
+
[tool.flwr.app.config]
|
|
128
|
+
num-server-rounds = 3
|
|
129
|
+
fraction-fit = 0.5
|
|
130
|
+
local-epochs = 1
|
|
131
|
+
|
|
132
|
+
[tool.flwr.federations]
|
|
133
|
+
default = "local-simulation"
|
|
134
|
+
|
|
135
|
+
[tool.flwr.federations.local-simulation]
|
|
136
|
+
options.num-supernodes = 10
|
|
137
|
+
options.backend.client-resources.num-cpus = 2
|
|
138
|
+
options.backend.client-resources.num-gpus = 0.0
|
|
@@ -8,15 +8,15 @@ version = "1.0.0"
|
|
|
8
8
|
description = ""
|
|
9
9
|
license = "Apache-2.0"
|
|
10
10
|
dependencies = [
|
|
11
|
-
"flwr[simulation]>=1.
|
|
12
|
-
"flwr-datasets>=0.
|
|
13
|
-
"hydra-core==1.3.2",
|
|
11
|
+
"flwr[simulation]>=1.10.0",
|
|
12
|
+
"flwr-datasets>=0.3.0",
|
|
14
13
|
"trl==0.8.1",
|
|
15
14
|
"bitsandbytes==0.43.0",
|
|
16
15
|
"scipy==1.13.0",
|
|
17
16
|
"peft==0.6.2",
|
|
18
17
|
"transformers==4.39.3",
|
|
19
18
|
"sentencepiece==0.2.0",
|
|
19
|
+
"omegaconf==2.3.0",
|
|
20
20
|
]
|
|
21
21
|
|
|
22
22
|
[tool.hatch.build.targets.wheel]
|
|
@@ -26,14 +26,41 @@ packages = ["."]
|
|
|
26
26
|
publisher = "$username"
|
|
27
27
|
|
|
28
28
|
[tool.flwr.app.components]
|
|
29
|
-
serverapp = "$import_name.app
|
|
30
|
-
clientapp = "$import_name.app
|
|
29
|
+
serverapp = "$import_name.server_app:app"
|
|
30
|
+
clientapp = "$import_name.client_app:app"
|
|
31
31
|
|
|
32
32
|
[tool.flwr.app.config]
|
|
33
|
-
|
|
33
|
+
model.name = "mistralai/Mistral-7B-v0.3"
|
|
34
|
+
model.quantization = 4
|
|
35
|
+
model.gradient-checkpointing = true
|
|
36
|
+
model.lora.peft-lora-r = 32
|
|
37
|
+
model.lora.peft-lora-alpha = 64
|
|
38
|
+
train.save-every-round = 5
|
|
39
|
+
train.learning-rate-max = 5e-5
|
|
40
|
+
train.learning-rate-min = 1e-6
|
|
41
|
+
train.seq-length = 512
|
|
42
|
+
train.training-arguments.output-dir = ""
|
|
43
|
+
train.training-arguments.learning-rate = ""
|
|
44
|
+
train.training-arguments.per-device-train-batch-size = 16
|
|
45
|
+
train.training-arguments.gradient-accumulation-steps = 1
|
|
46
|
+
train.training-arguments.logging-steps = 10
|
|
47
|
+
train.training-arguments.num-train-epochs = 3
|
|
48
|
+
train.training-arguments.max-steps = 10
|
|
49
|
+
train.training-arguments.save-steps = 1000
|
|
50
|
+
train.training-arguments.save-total-limit = 10
|
|
51
|
+
train.training-arguments.gradient-checkpointing = true
|
|
52
|
+
train.training-arguments.lr-scheduler-type = "constant"
|
|
53
|
+
strategy.fraction-fit = $fraction_fit
|
|
54
|
+
strategy.fraction-evaluate = 0.0
|
|
55
|
+
num-server-rounds = 200
|
|
56
|
+
|
|
57
|
+
[tool.flwr.app.config.static]
|
|
58
|
+
dataset.name = "$dataset_name"
|
|
34
59
|
|
|
35
60
|
[tool.flwr.federations]
|
|
36
61
|
default = "local-simulation"
|
|
37
62
|
|
|
38
63
|
[tool.flwr.federations.local-simulation]
|
|
39
|
-
options.num-supernodes =
|
|
64
|
+
options.num-supernodes = $num_clients
|
|
65
|
+
options.backend.client-resources.num-cpus = 6
|
|
66
|
+
options.backend.client-resources.num-gpus = 1.0
|
|
@@ -8,7 +8,7 @@ version = "1.0.0"
|
|
|
8
8
|
description = ""
|
|
9
9
|
license = "Apache-2.0"
|
|
10
10
|
dependencies = [
|
|
11
|
-
"flwr[simulation]>=1.
|
|
11
|
+
"flwr[simulation]>=1.11.0",
|
|
12
12
|
"flwr-datasets>=0.3.0",
|
|
13
13
|
"torch==2.2.1",
|
|
14
14
|
"transformers>=4.30.0,<5.0",
|
|
@@ -29,10 +29,18 @@ clientapp = "$import_name.client_app:app"
|
|
|
29
29
|
|
|
30
30
|
[tool.flwr.app.config]
|
|
31
31
|
num-server-rounds = 3
|
|
32
|
+
fraction-fit = 0.5
|
|
32
33
|
local-epochs = 1
|
|
34
|
+
model-name = "prajjwal1/bert-tiny" # Set a larger model if you have access to more GPU resources
|
|
35
|
+
num-labels = 2
|
|
33
36
|
|
|
34
37
|
[tool.flwr.federations]
|
|
35
38
|
default = "localhost"
|
|
36
39
|
|
|
37
40
|
[tool.flwr.federations.localhost]
|
|
38
41
|
options.num-supernodes = 10
|
|
42
|
+
|
|
43
|
+
[tool.flwr.federations.localhost-gpu]
|
|
44
|
+
options.num-supernodes = 10
|
|
45
|
+
options.backend.client-resources.num-cpus = 4 # each ClientApp is assumed to use 4 CPUs
|
|
46
|
+
options.backend.client-resources.num-gpus = 0.25 # at most 4 ClientApps will run on a given GPU
|
flwr/cli/run/run.py
CHANGED
|
@@ -124,14 +124,14 @@ def run(
|
|
|
124
124
|
|
|
125
125
|
|
|
126
126
|
def _run_with_superexec(
|
|
127
|
-
app:
|
|
127
|
+
app: Path,
|
|
128
128
|
federation_config: Dict[str, Any],
|
|
129
129
|
config_overrides: Optional[List[str]],
|
|
130
130
|
) -> None:
|
|
131
131
|
|
|
132
132
|
insecure_str = federation_config.get("insecure")
|
|
133
133
|
if root_certificates := federation_config.get("root-certificates"):
|
|
134
|
-
root_certificates_bytes =
|
|
134
|
+
root_certificates_bytes = (app / root_certificates).read_bytes()
|
|
135
135
|
if insecure := bool(insecure_str):
|
|
136
136
|
typer.secho(
|
|
137
137
|
"❌ `root_certificates` were provided but the `insecure` parameter"
|
flwr/client/__init__.py
CHANGED
|
@@ -20,8 +20,6 @@ from .app import start_numpy_client as start_numpy_client
|
|
|
20
20
|
from .client import Client as Client
|
|
21
21
|
from .client_app import ClientApp as ClientApp
|
|
22
22
|
from .numpy_client import NumPyClient as NumPyClient
|
|
23
|
-
from .supernode import run_client_app as run_client_app
|
|
24
|
-
from .supernode import run_supernode as run_supernode
|
|
25
23
|
from .typing import ClientFn as ClientFn
|
|
26
24
|
from .typing import ClientFnExt as ClientFnExt
|
|
27
25
|
|
|
@@ -32,8 +30,6 @@ __all__ = [
|
|
|
32
30
|
"ClientFnExt",
|
|
33
31
|
"NumPyClient",
|
|
34
32
|
"mod",
|
|
35
|
-
"run_client_app",
|
|
36
|
-
"run_supernode",
|
|
37
33
|
"start_client",
|
|
38
34
|
"start_numpy_client",
|
|
39
35
|
]
|
flwr/client/app.py
CHANGED
|
@@ -35,6 +35,7 @@ from flwr.client.typing import ClientFnExt
|
|
|
35
35
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH, Context, EventType, Message, event
|
|
36
36
|
from flwr.common.address import parse_address
|
|
37
37
|
from flwr.common.constant import (
|
|
38
|
+
CLIENTAPPIO_API_DEFAULT_ADDRESS,
|
|
38
39
|
MISSING_EXTRA_REST,
|
|
39
40
|
RUN_ID_NUM_BYTES,
|
|
40
41
|
TRANSPORT_TYPE_GRPC_ADAPTER,
|
|
@@ -60,8 +61,6 @@ from .message_handler.message_handler import handle_control_message
|
|
|
60
61
|
from .node_state import NodeState
|
|
61
62
|
from .numpy_client import NumPyClient
|
|
62
63
|
|
|
63
|
-
ADDRESS_CLIENTAPPIO_API_GRPC_RERE = "0.0.0.0:9094"
|
|
64
|
-
|
|
65
64
|
ISOLATION_MODE_SUBPROCESS = "subprocess"
|
|
66
65
|
ISOLATION_MODE_PROCESS = "process"
|
|
67
66
|
|
|
@@ -211,7 +210,7 @@ def start_client_internal(
|
|
|
211
210
|
max_wait_time: Optional[float] = None,
|
|
212
211
|
flwr_path: Optional[Path] = None,
|
|
213
212
|
isolation: Optional[str] = None,
|
|
214
|
-
supernode_address: Optional[str] =
|
|
213
|
+
supernode_address: Optional[str] = CLIENTAPPIO_API_DEFAULT_ADDRESS,
|
|
215
214
|
) -> None:
|
|
216
215
|
"""Start a Flower client node which connects to a Flower server.
|
|
217
216
|
|
|
@@ -266,7 +265,7 @@ def start_client_internal(
|
|
|
266
265
|
by the SuperNode and communicates using gRPC at the address
|
|
267
266
|
`supernode_address`. If `process`, the `ClientApp` runs in a separate isolated
|
|
268
267
|
process and communicates using gRPC at the address `supernode_address`.
|
|
269
|
-
supernode_address : Optional[str] (default: `
|
|
268
|
+
supernode_address : Optional[str] (default: `CLIENTAPPIO_API_DEFAULT_ADDRESS`)
|
|
270
269
|
The SuperNode gRPC server address.
|
|
271
270
|
"""
|
|
272
271
|
if insecure is None:
|
flwr/client/client_app.py
CHANGED
|
@@ -41,11 +41,11 @@ def _alert_erroneous_client_fn() -> None:
|
|
|
41
41
|
|
|
42
42
|
def _inspect_maybe_adapt_client_fn_signature(client_fn: ClientFnExt) -> ClientFnExt:
|
|
43
43
|
client_fn_args = inspect.signature(client_fn).parameters
|
|
44
|
-
first_arg = list(client_fn_args.keys())[0]
|
|
45
44
|
|
|
46
45
|
if len(client_fn_args) != 1:
|
|
47
46
|
_alert_erroneous_client_fn()
|
|
48
47
|
|
|
48
|
+
first_arg = list(client_fn_args.keys())[0]
|
|
49
49
|
first_arg_type = client_fn_args[first_arg].annotation
|
|
50
50
|
|
|
51
51
|
if first_arg_type is str or first_arg == "cid":
|
|
@@ -263,7 +263,7 @@ def _registration_error(fn_name: str) -> ValueError:
|
|
|
263
263
|
>>> class FlowerClient(NumPyClient):
|
|
264
264
|
>>> # ...
|
|
265
265
|
>>>
|
|
266
|
-
>>> def client_fn(
|
|
266
|
+
>>> def client_fn(context: Context):
|
|
267
267
|
>>> return FlowerClient().to_client()
|
|
268
268
|
>>>
|
|
269
269
|
>>> app = ClientApp(
|
|
@@ -17,11 +17,13 @@
|
|
|
17
17
|
|
|
18
18
|
import base64
|
|
19
19
|
import collections
|
|
20
|
+
from logging import WARNING
|
|
20
21
|
from typing import Any, Callable, Optional, Sequence, Tuple, Union
|
|
21
22
|
|
|
22
23
|
import grpc
|
|
23
24
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
24
25
|
|
|
26
|
+
from flwr.common.logger import log
|
|
25
27
|
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
|
26
28
|
bytes_to_public_key,
|
|
27
29
|
compute_hmac,
|
|
@@ -128,13 +130,12 @@ class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type:
|
|
|
128
130
|
if self.shared_secret is None:
|
|
129
131
|
raise RuntimeError("Failure to compute hmac")
|
|
130
132
|
|
|
133
|
+
message_bytes = request.SerializeToString(deterministic=True)
|
|
131
134
|
metadata.append(
|
|
132
135
|
(
|
|
133
136
|
_AUTH_TOKEN_HEADER,
|
|
134
137
|
base64.urlsafe_b64encode(
|
|
135
|
-
compute_hmac(
|
|
136
|
-
self.shared_secret, request.SerializeToString(True)
|
|
137
|
-
)
|
|
138
|
+
compute_hmac(self.shared_secret, message_bytes)
|
|
138
139
|
),
|
|
139
140
|
)
|
|
140
141
|
)
|
|
@@ -151,8 +152,15 @@ class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type:
|
|
|
151
152
|
server_public_key_bytes = base64.urlsafe_b64decode(
|
|
152
153
|
_get_value_from_tuples(_PUBLIC_KEY_HEADER, response.initial_metadata())
|
|
153
154
|
)
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
self.
|
|
157
|
-
|
|
155
|
+
|
|
156
|
+
if server_public_key_bytes != b"":
|
|
157
|
+
self.server_public_key = bytes_to_public_key(server_public_key_bytes)
|
|
158
|
+
else:
|
|
159
|
+
log(WARNING, "Can't get server public key, SuperLink may be offline")
|
|
160
|
+
|
|
161
|
+
if self.server_public_key is not None:
|
|
162
|
+
self.shared_secret = generate_shared_key(
|
|
163
|
+
self.private_key, self.server_public_key
|
|
164
|
+
)
|
|
165
|
+
|
|
158
166
|
return response
|
flwr/client/supernode/app.py
CHANGED
|
@@ -30,6 +30,7 @@ from cryptography.hazmat.primitives.serialization import (
|
|
|
30
30
|
from flwr.common import EventType, event
|
|
31
31
|
from flwr.common.config import parse_config_args
|
|
32
32
|
from flwr.common.constant import (
|
|
33
|
+
FLEET_API_GRPC_RERE_DEFAULT_ADDRESS,
|
|
33
34
|
TRANSPORT_TYPE_GRPC_ADAPTER,
|
|
34
35
|
TRANSPORT_TYPE_GRPC_RERE,
|
|
35
36
|
TRANSPORT_TYPE_REST,
|
|
@@ -44,8 +45,6 @@ from ..app import (
|
|
|
44
45
|
)
|
|
45
46
|
from ..clientapp.utils import get_load_client_app_fn
|
|
46
47
|
|
|
47
|
-
ADDRESS_FLEET_API_GRPC_RERE = "0.0.0.0:9092"
|
|
48
|
-
|
|
49
48
|
|
|
50
49
|
def run_supernode() -> None:
|
|
51
50
|
"""Run Flower SuperNode."""
|
|
@@ -77,7 +76,9 @@ def run_supernode() -> None:
|
|
|
77
76
|
authentication_keys=authentication_keys,
|
|
78
77
|
max_retries=args.max_retries,
|
|
79
78
|
max_wait_time=args.max_wait_time,
|
|
80
|
-
node_config=parse_config_args(
|
|
79
|
+
node_config=parse_config_args(
|
|
80
|
+
[args.node_config] if args.node_config else args.node_config
|
|
81
|
+
),
|
|
81
82
|
isolation=args.isolation,
|
|
82
83
|
supernode_address=args.supernode_address,
|
|
83
84
|
)
|
|
@@ -101,11 +102,11 @@ def run_client_app() -> None:
|
|
|
101
102
|
|
|
102
103
|
def _warn_deprecated_server_arg(args: argparse.Namespace) -> None:
|
|
103
104
|
"""Warn about the deprecated argument `--server`."""
|
|
104
|
-
if args.server !=
|
|
105
|
+
if args.server != FLEET_API_GRPC_RERE_DEFAULT_ADDRESS:
|
|
105
106
|
warn = "Passing flag --server is deprecated. Use --superlink instead."
|
|
106
107
|
warn_deprecated_feature(warn)
|
|
107
108
|
|
|
108
|
-
if args.superlink !=
|
|
109
|
+
if args.superlink != FLEET_API_GRPC_RERE_DEFAULT_ADDRESS:
|
|
109
110
|
# if `--superlink` also passed, then
|
|
110
111
|
# warn user that this argument overrides what was passed with `--server`
|
|
111
112
|
log(
|
|
@@ -245,12 +246,12 @@ def _parse_args_common(parser: argparse.ArgumentParser) -> None:
|
|
|
245
246
|
)
|
|
246
247
|
parser.add_argument(
|
|
247
248
|
"--server",
|
|
248
|
-
default=
|
|
249
|
+
default=FLEET_API_GRPC_RERE_DEFAULT_ADDRESS,
|
|
249
250
|
help="Server address",
|
|
250
251
|
)
|
|
251
252
|
parser.add_argument(
|
|
252
253
|
"--superlink",
|
|
253
|
-
default=
|
|
254
|
+
default=FLEET_API_GRPC_RERE_DEFAULT_ADDRESS,
|
|
254
255
|
help="SuperLink Fleet API (gRPC-rere) address (IPv4, IPv6, or a domain name)",
|
|
255
256
|
)
|
|
256
257
|
parser.add_argument(
|
flwr/common/config.py
CHANGED
|
@@ -185,23 +185,26 @@ def parse_config_args(
|
|
|
185
185
|
if config is None:
|
|
186
186
|
return overrides
|
|
187
187
|
|
|
188
|
+
# Handle if .toml file is passed
|
|
189
|
+
if len(config) == 1 and config[0].endswith(".toml"):
|
|
190
|
+
with Path(config[0]).open("rb") as config_file:
|
|
191
|
+
overrides = flatten_dict(tomli.load(config_file))
|
|
192
|
+
return overrides
|
|
193
|
+
|
|
188
194
|
# Regular expression to capture key-value pairs with possible quoted values
|
|
189
195
|
pattern = re.compile(r"(\S+?)=(\'[^\']*\'|\"[^\"]*\"|\S+)")
|
|
190
196
|
|
|
191
197
|
for config_line in config:
|
|
192
198
|
if config_line:
|
|
193
|
-
|
|
199
|
+
# .toml files aren't allowed alongside other configs
|
|
200
|
+
if config_line.endswith(".toml"):
|
|
201
|
+
raise ValueError(
|
|
202
|
+
"TOML files cannot be passed alongside key-value pairs."
|
|
203
|
+
)
|
|
194
204
|
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
and matches[0][0].endswith(".toml")
|
|
199
|
-
):
|
|
200
|
-
with Path(matches[0][0]).open("rb") as config_file:
|
|
201
|
-
overrides = flatten_dict(tomli.load(config_file))
|
|
202
|
-
else:
|
|
203
|
-
toml_str = "\n".join(f"{k} = {v}" for k, v in matches)
|
|
204
|
-
overrides.update(tomli.loads(toml_str))
|
|
205
|
+
matches = pattern.findall(config_line)
|
|
206
|
+
toml_str = "\n".join(f"{k} = {v}" for k, v in matches)
|
|
207
|
+
overrides.update(tomli.loads(toml_str))
|
|
205
208
|
|
|
206
209
|
return overrides
|
|
207
210
|
|
flwr/common/constant.py
CHANGED
|
@@ -37,7 +37,18 @@ TRANSPORT_TYPES = [
|
|
|
37
37
|
TRANSPORT_TYPE_VCE,
|
|
38
38
|
]
|
|
39
39
|
|
|
40
|
-
|
|
40
|
+
# Addresses
|
|
41
|
+
# SuperNode
|
|
42
|
+
CLIENTAPPIO_API_DEFAULT_ADDRESS = "0.0.0.0:9094"
|
|
43
|
+
# SuperExec
|
|
44
|
+
EXEC_API_DEFAULT_ADDRESS = "0.0.0.0:9093"
|
|
45
|
+
# SuperLink
|
|
46
|
+
DRIVER_API_DEFAULT_ADDRESS = "0.0.0.0:9091"
|
|
47
|
+
FLEET_API_GRPC_RERE_DEFAULT_ADDRESS = "0.0.0.0:9092"
|
|
48
|
+
FLEET_API_GRPC_BIDI_DEFAULT_ADDRESS = (
|
|
49
|
+
"[::]:8080" # IPv6 to keep start_server compatible
|
|
50
|
+
)
|
|
51
|
+
FLEET_API_REST_DEFAULT_ADDRESS = "0.0.0.0:9093"
|
|
41
52
|
|
|
42
53
|
# Constants for ping
|
|
43
54
|
PING_DEFAULT_INTERVAL = 30
|
flwr/common/record/recordset.py
CHANGED
|
@@ -119,7 +119,7 @@ class RecordSet:
|
|
|
119
119
|
Let's see an example.
|
|
120
120
|
|
|
121
121
|
>>> from flwr.common import RecordSet
|
|
122
|
-
>>> from flwr.common import
|
|
122
|
+
>>> from flwr.common import ConfigsRecord, MetricsRecord, ParametersRecord
|
|
123
123
|
>>>
|
|
124
124
|
>>> # Let's begin with an empty record
|
|
125
125
|
>>> my_recordset = RecordSet()
|
flwr/common/record/typeddict.py
CHANGED
|
@@ -15,7 +15,18 @@
|
|
|
15
15
|
"""Typed dict base class for *Records."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from typing import
|
|
18
|
+
from typing import (
|
|
19
|
+
Callable,
|
|
20
|
+
Dict,
|
|
21
|
+
Generic,
|
|
22
|
+
ItemsView,
|
|
23
|
+
Iterator,
|
|
24
|
+
KeysView,
|
|
25
|
+
MutableMapping,
|
|
26
|
+
TypeVar,
|
|
27
|
+
ValuesView,
|
|
28
|
+
cast,
|
|
29
|
+
)
|
|
19
30
|
|
|
20
31
|
K = TypeVar("K") # Key type
|
|
21
32
|
V = TypeVar("V") # Value type
|
|
@@ -73,3 +84,15 @@ class TypedDict(MutableMapping[K, V], Generic[K, V]):
|
|
|
73
84
|
if isinstance(other, dict):
|
|
74
85
|
return data == other
|
|
75
86
|
return NotImplemented
|
|
87
|
+
|
|
88
|
+
def keys(self) -> KeysView[K]:
|
|
89
|
+
"""D.keys() -> a set-like object providing a view on D's keys."""
|
|
90
|
+
return cast(Dict[K, V], self.__dict__["_data"]).keys()
|
|
91
|
+
|
|
92
|
+
def values(self) -> ValuesView[V]:
|
|
93
|
+
"""D.values() -> an object providing a view on D's values."""
|
|
94
|
+
return cast(Dict[K, V], self.__dict__["_data"]).values()
|
|
95
|
+
|
|
96
|
+
def items(self) -> ItemsView[K, V]:
|
|
97
|
+
"""D.items() -> a set-like object providing a view on D's items."""
|
|
98
|
+
return cast(Dict[K, V], self.__dict__["_data"]).items()
|