flwr-nightly 1.9.0.dev20240420__py3-none-any.whl → 1.9.0.dev20240507__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/app.py +2 -0
- flwr/cli/build.py +151 -0
- flwr/cli/config_utils.py +18 -46
- flwr/cli/new/new.py +42 -18
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +70 -0
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +1 -1
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +94 -0
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +15 -29
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +15 -0
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +17 -0
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +9 -1
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +89 -0
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +29 -0
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +28 -0
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +7 -4
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +7 -4
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +27 -0
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +7 -4
- flwr/cli/run/run.py +1 -1
- flwr/cli/utils.py +18 -17
- flwr/client/__init__.py +1 -1
- flwr/client/app.py +17 -93
- flwr/client/grpc_client/connection.py +6 -1
- flwr/client/grpc_rere_client/client_interceptor.py +158 -0
- flwr/client/grpc_rere_client/connection.py +17 -2
- flwr/client/mod/centraldp_mods.py +4 -2
- flwr/client/mod/localdp_mod.py +9 -3
- flwr/client/rest_client/connection.py +5 -1
- flwr/client/supernode/__init__.py +2 -0
- flwr/client/supernode/app.py +181 -7
- flwr/common/grpc.py +5 -1
- flwr/common/logger.py +37 -4
- flwr/common/message.py +105 -86
- flwr/common/record/parametersrecord.py +0 -1
- flwr/common/record/recordset.py +17 -5
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +35 -1
- flwr/server/app.py +111 -1
- flwr/server/compat/app.py +2 -2
- flwr/server/compat/app_utils.py +1 -1
- flwr/server/compat/driver_client_proxy.py +27 -72
- flwr/server/driver/__init__.py +3 -0
- flwr/server/driver/driver.py +12 -242
- flwr/server/driver/grpc_driver.py +315 -0
- flwr/server/run_serverapp.py +18 -4
- flwr/server/strategy/dp_adaptive_clipping.py +5 -3
- flwr/server/strategy/dp_fixed_clipping.py +6 -3
- flwr/server/superlink/driver/driver_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +3 -1
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +215 -0
- flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -5
- flwr/server/superlink/fleet/vce/vce_api.py +1 -1
- flwr/server/superlink/state/in_memory_state.py +76 -8
- flwr/server/superlink/state/sqlite_state.py +116 -11
- flwr/server/superlink/state/state.py +35 -3
- flwr/simulation/__init__.py +2 -2
- flwr/simulation/app.py +16 -1
- flwr/simulation/run_simulation.py +10 -7
- {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/METADATA +3 -2
- {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/RECORD +63 -52
- {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/entry_points.txt +1 -1
- flwr/server/driver/abc_driver.py +0 -140
- {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240507.dist-info}/WHEEL +0 -0
|
@@ -1,18 +1,26 @@
|
|
|
1
1
|
"""$project_name: A Flower / TensorFlow app."""
|
|
2
2
|
|
|
3
|
+
from flwr.common import ndarrays_to_parameters
|
|
3
4
|
from flwr.server import ServerApp, ServerConfig
|
|
4
5
|
from flwr.server.strategy import FedAvg
|
|
5
6
|
|
|
7
|
+
from $import_name.task import load_model
|
|
8
|
+
|
|
6
9
|
# Define config
|
|
7
10
|
config = ServerConfig(num_rounds=3)
|
|
8
11
|
|
|
12
|
+
parameters = ndarrays_to_parameters(load_model().get_weights())
|
|
13
|
+
|
|
14
|
+
# Define strategy
|
|
9
15
|
strategy = FedAvg(
|
|
10
16
|
fraction_fit=1.0,
|
|
11
17
|
fraction_evaluate=1.0,
|
|
12
18
|
min_available_clients=2,
|
|
19
|
+
initial_parameters=parameters,
|
|
13
20
|
)
|
|
14
21
|
|
|
15
|
-
|
|
22
|
+
|
|
23
|
+
# Create ServerApp
|
|
16
24
|
app = ServerApp(
|
|
17
25
|
config=config,
|
|
18
26
|
strategy=strategy,
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
"""$project_name: A Flower / MLX app."""
|
|
2
|
+
|
|
3
|
+
import mlx.core as mx
|
|
4
|
+
import mlx.nn as nn
|
|
5
|
+
import numpy as np
|
|
6
|
+
from datasets.utils.logging import disable_progress_bar
|
|
7
|
+
from flwr_datasets import FederatedDataset
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
disable_progress_bar()
|
|
11
|
+
|
|
12
|
+
class MLP(nn.Module):
|
|
13
|
+
"""A simple MLP."""
|
|
14
|
+
|
|
15
|
+
def __init__(
|
|
16
|
+
self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
|
|
17
|
+
):
|
|
18
|
+
super().__init__()
|
|
19
|
+
layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
|
|
20
|
+
self.layers = [
|
|
21
|
+
nn.Linear(idim, odim)
|
|
22
|
+
for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
|
|
23
|
+
]
|
|
24
|
+
|
|
25
|
+
def __call__(self, x):
|
|
26
|
+
for l in self.layers[:-1]:
|
|
27
|
+
x = mx.maximum(l(x), 0.0)
|
|
28
|
+
return self.layers[-1](x)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def loss_fn(model, X, y):
|
|
32
|
+
return mx.mean(nn.losses.cross_entropy(model(X), y))
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def eval_fn(model, X, y):
|
|
36
|
+
return mx.mean(mx.argmax(model(X), axis=1) == y)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def batch_iterate(batch_size, X, y):
|
|
40
|
+
perm = mx.array(np.random.permutation(y.size))
|
|
41
|
+
for s in range(0, y.size, batch_size):
|
|
42
|
+
ids = perm[s : s + batch_size]
|
|
43
|
+
yield X[ids], y[ids]
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def load_data(partition_id, num_clients):
|
|
47
|
+
fds = FederatedDataset(dataset="mnist", partitioners={"train": num_clients})
|
|
48
|
+
partition = fds.load_partition(partition_id)
|
|
49
|
+
partition_splits = partition.train_test_split(test_size=0.2, seed=42)
|
|
50
|
+
|
|
51
|
+
partition_splits["train"].set_format("numpy")
|
|
52
|
+
partition_splits["test"].set_format("numpy")
|
|
53
|
+
|
|
54
|
+
train_partition = partition_splits["train"].map(
|
|
55
|
+
lambda img: {
|
|
56
|
+
"img": img.reshape(-1, 28 * 28).squeeze().astype(np.float32) / 255.0
|
|
57
|
+
},
|
|
58
|
+
input_columns="image",
|
|
59
|
+
)
|
|
60
|
+
test_partition = partition_splits["test"].map(
|
|
61
|
+
lambda img: {
|
|
62
|
+
"img": img.reshape(-1, 28 * 28).squeeze().astype(np.float32) / 255.0
|
|
63
|
+
},
|
|
64
|
+
input_columns="image",
|
|
65
|
+
)
|
|
66
|
+
|
|
67
|
+
data = (
|
|
68
|
+
train_partition["img"],
|
|
69
|
+
train_partition["label"].astype(np.uint32),
|
|
70
|
+
test_partition["img"],
|
|
71
|
+
test_partition["label"].astype(np.uint32),
|
|
72
|
+
)
|
|
73
|
+
|
|
74
|
+
train_images, train_labels, test_images, test_labels = map(mx.array, data)
|
|
75
|
+
return train_images, train_labels, test_images, test_labels
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def get_params(model):
|
|
79
|
+
layers = model.parameters()["layers"]
|
|
80
|
+
return [np.array(val) for layer in layers for _, val in layer.items()]
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def set_params(model, parameters):
|
|
84
|
+
new_params = {}
|
|
85
|
+
new_params["layers"] = [
|
|
86
|
+
{"weight": mx.array(parameters[i]), "bias": mx.array(parameters[i + 1])}
|
|
87
|
+
for i in range(0, len(parameters), 2)
|
|
88
|
+
]
|
|
89
|
+
model.update(new_params)
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
"""$project_name: A Flower / TensorFlow app."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
|
|
5
|
+
import tensorflow as tf
|
|
6
|
+
from flwr_datasets import FederatedDataset
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
# Make TensorFlow log less verbose
|
|
10
|
+
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
|
11
|
+
|
|
12
|
+
def load_model():
|
|
13
|
+
# Load model and data (MobileNetV2, CIFAR-10)
|
|
14
|
+
model = tf.keras.applications.MobileNetV2((32, 32, 3), classes=10, weights=None)
|
|
15
|
+
model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"])
|
|
16
|
+
return model
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def load_data(partition_id, num_partitions):
|
|
20
|
+
# Download and partition dataset
|
|
21
|
+
fds = FederatedDataset(dataset="cifar10", partitioners={"train": num_partitions})
|
|
22
|
+
partition = fds.load_partition(partition_id, "train")
|
|
23
|
+
partition.set_format("numpy")
|
|
24
|
+
|
|
25
|
+
# Divide data on each node: 80% train, 20% test
|
|
26
|
+
partition = partition.train_test_split(test_size=0.2)
|
|
27
|
+
x_train, y_train = partition["train"]["img"] / 255.0, partition["train"]["label"]
|
|
28
|
+
x_test, y_test = partition["test"]["img"] / 255.0, partition["test"]["label"]
|
|
29
|
+
return x_train, y_train, x_test, y_test
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["hatchling"]
|
|
3
|
+
build-backend = "hatchling.build"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "$package_name"
|
|
7
|
+
version = "1.0.0"
|
|
8
|
+
description = ""
|
|
9
|
+
authors = [
|
|
10
|
+
{ name = "The Flower Authors", email = "hello@flower.ai" },
|
|
11
|
+
]
|
|
12
|
+
license = { text = "Apache License (2.0)" }
|
|
13
|
+
dependencies = [
|
|
14
|
+
"flwr[simulation]>=1.8.0,<2.0",
|
|
15
|
+
"flwr-datasets[vision]>=0.0.2,<1.0.0",
|
|
16
|
+
"mlx==0.10.0",
|
|
17
|
+
"numpy==1.24.4",
|
|
18
|
+
]
|
|
19
|
+
|
|
20
|
+
[tool.hatch.build.targets.wheel]
|
|
21
|
+
packages = ["."]
|
|
22
|
+
|
|
23
|
+
[flower]
|
|
24
|
+
publisher = "$username"
|
|
25
|
+
|
|
26
|
+
[flower.components]
|
|
27
|
+
serverapp = "$import_name.server:app"
|
|
28
|
+
clientapp = "$import_name.client:app"
|
|
@@ -3,13 +3,13 @@ requires = ["hatchling"]
|
|
|
3
3
|
build-backend = "hatchling.build"
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
|
-
name = "$
|
|
6
|
+
name = "$package_name"
|
|
7
7
|
version = "1.0.0"
|
|
8
8
|
description = ""
|
|
9
9
|
authors = [
|
|
10
10
|
{ name = "The Flower Authors", email = "hello@flower.ai" },
|
|
11
11
|
]
|
|
12
|
-
license = {text = "Apache License (2.0)"}
|
|
12
|
+
license = { text = "Apache License (2.0)" }
|
|
13
13
|
dependencies = [
|
|
14
14
|
"flwr[simulation]>=1.8.0,<2.0",
|
|
15
15
|
"numpy>=1.21.0",
|
|
@@ -18,6 +18,9 @@ dependencies = [
|
|
|
18
18
|
[tool.hatch.build.targets.wheel]
|
|
19
19
|
packages = ["."]
|
|
20
20
|
|
|
21
|
+
[flower]
|
|
22
|
+
publisher = "$username"
|
|
23
|
+
|
|
21
24
|
[flower.components]
|
|
22
|
-
serverapp = "$
|
|
23
|
-
clientapp = "$
|
|
25
|
+
serverapp = "$import_name.server:app"
|
|
26
|
+
clientapp = "$import_name.client:app"
|
|
@@ -3,13 +3,13 @@ requires = ["hatchling"]
|
|
|
3
3
|
build-backend = "hatchling.build"
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
|
-
name = "$
|
|
6
|
+
name = "$package_name"
|
|
7
7
|
version = "1.0.0"
|
|
8
8
|
description = ""
|
|
9
9
|
authors = [
|
|
10
10
|
{ name = "The Flower Authors", email = "hello@flower.ai" },
|
|
11
11
|
]
|
|
12
|
-
license = {text = "Apache License (2.0)"}
|
|
12
|
+
license = { text = "Apache License (2.0)" }
|
|
13
13
|
dependencies = [
|
|
14
14
|
"flwr[simulation]>=1.8.0,<2.0",
|
|
15
15
|
"flwr-datasets[vision]>=0.0.2,<1.0.0",
|
|
@@ -20,6 +20,9 @@ dependencies = [
|
|
|
20
20
|
[tool.hatch.build.targets.wheel]
|
|
21
21
|
packages = ["."]
|
|
22
22
|
|
|
23
|
+
[flower]
|
|
24
|
+
publisher = "$username"
|
|
25
|
+
|
|
23
26
|
[flower.components]
|
|
24
|
-
serverapp = "$
|
|
25
|
-
clientapp = "$
|
|
27
|
+
serverapp = "$import_name.server:app"
|
|
28
|
+
clientapp = "$import_name.client:app"
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["hatchling"]
|
|
3
|
+
build-backend = "hatchling.build"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "$package_name"
|
|
7
|
+
version = "1.0.0"
|
|
8
|
+
description = ""
|
|
9
|
+
authors = [
|
|
10
|
+
{ name = "The Flower Authors", email = "hello@flower.ai" },
|
|
11
|
+
]
|
|
12
|
+
license = { text = "Apache License (2.0)" }
|
|
13
|
+
dependencies = [
|
|
14
|
+
"flwr[simulation]>=1.8.0,<2.0",
|
|
15
|
+
"flwr-datasets[vision]>=0.0.2,<1.0.0",
|
|
16
|
+
"scikit-learn>=1.1.1",
|
|
17
|
+
]
|
|
18
|
+
|
|
19
|
+
[tool.hatch.build.targets.wheel]
|
|
20
|
+
packages = ["."]
|
|
21
|
+
|
|
22
|
+
[flower]
|
|
23
|
+
publisher = "$username"
|
|
24
|
+
|
|
25
|
+
[flower.components]
|
|
26
|
+
serverapp = "$import_name.server:app"
|
|
27
|
+
clientapp = "$import_name.client:app"
|
|
@@ -3,13 +3,13 @@ requires = ["hatchling"]
|
|
|
3
3
|
build-backend = "hatchling.build"
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
|
-
name = "$
|
|
6
|
+
name = "$package_name"
|
|
7
7
|
version = "1.0.0"
|
|
8
8
|
description = ""
|
|
9
9
|
authors = [
|
|
10
10
|
{ name = "The Flower Authors", email = "hello@flower.ai" },
|
|
11
11
|
]
|
|
12
|
-
license = {text = "Apache License (2.0)"}
|
|
12
|
+
license = { text = "Apache License (2.0)" }
|
|
13
13
|
dependencies = [
|
|
14
14
|
"flwr[simulation]>=1.8.0,<2.0",
|
|
15
15
|
"flwr-datasets[vision]>=0.0.2,<1.0.0",
|
|
@@ -19,6 +19,9 @@ dependencies = [
|
|
|
19
19
|
[tool.hatch.build.targets.wheel]
|
|
20
20
|
packages = ["."]
|
|
21
21
|
|
|
22
|
+
[flower]
|
|
23
|
+
publisher = "$username"
|
|
24
|
+
|
|
22
25
|
[flower.components]
|
|
23
|
-
serverapp = "$
|
|
24
|
-
clientapp = "$
|
|
26
|
+
serverapp = "$import_name.server:app"
|
|
27
|
+
clientapp = "$import_name.client:app"
|
flwr/cli/run/run.py
CHANGED
|
@@ -30,7 +30,7 @@ def run() -> None:
|
|
|
30
30
|
|
|
31
31
|
if config is None:
|
|
32
32
|
typer.secho(
|
|
33
|
-
"Project configuration could not be loaded.\
|
|
33
|
+
"Project configuration could not be loaded.\npyproject.toml is invalid:\n"
|
|
34
34
|
+ "\n".join([f"- {line}" for line in errors]),
|
|
35
35
|
fg=typer.colors.RED,
|
|
36
36
|
bold=True,
|
flwr/cli/utils.py
CHANGED
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Flower command line interface utils."""
|
|
16
16
|
|
|
17
|
+
import re
|
|
17
18
|
from typing import Callable, List, Optional, cast
|
|
18
19
|
|
|
19
20
|
import typer
|
|
@@ -73,51 +74,51 @@ def prompt_options(text: str, options: List[str]) -> str:
|
|
|
73
74
|
|
|
74
75
|
|
|
75
76
|
def is_valid_project_name(name: str) -> bool:
|
|
76
|
-
"""Check if the given string is a valid Python
|
|
77
|
+
"""Check if the given string is a valid Python project name.
|
|
77
78
|
|
|
78
|
-
A valid
|
|
79
|
-
|
|
79
|
+
A valid project name must start with a letter and can only contain letters, digits,
|
|
80
|
+
and hyphens.
|
|
80
81
|
"""
|
|
81
82
|
if not name:
|
|
82
83
|
return False
|
|
83
84
|
|
|
84
|
-
# Check if the first character is a letter
|
|
85
|
-
if not
|
|
85
|
+
# Check if the first character is a letter
|
|
86
|
+
if not name[0].isalpha():
|
|
86
87
|
return False
|
|
87
88
|
|
|
88
|
-
# Check if the rest of the characters are valid (letter, digit, or
|
|
89
|
+
# Check if the rest of the characters are valid (letter, digit, or dash)
|
|
89
90
|
for char in name[1:]:
|
|
90
|
-
if not (char.isalnum() or char
|
|
91
|
+
if not (char.isalnum() or char in "-"):
|
|
91
92
|
return False
|
|
92
93
|
|
|
93
94
|
return True
|
|
94
95
|
|
|
95
96
|
|
|
96
97
|
def sanitize_project_name(name: str) -> str:
|
|
97
|
-
"""Sanitize the given string to make it a valid Python
|
|
98
|
+
"""Sanitize the given string to make it a valid Python project name.
|
|
98
99
|
|
|
99
|
-
This version replaces
|
|
100
|
-
in Python
|
|
101
|
-
valid character.
|
|
100
|
+
This version replaces spaces, dots, slashes, and underscores with dashes, removes
|
|
101
|
+
any characters not allowed in Python project names, makes the string lowercase, and
|
|
102
|
+
ensures it starts with a valid character.
|
|
102
103
|
"""
|
|
103
|
-
# Replace
|
|
104
|
-
|
|
104
|
+
# Replace whitespace with '_'
|
|
105
|
+
name_with_hyphens = re.sub(r"[ ./_]", "-", name)
|
|
105
106
|
|
|
106
107
|
# Allowed characters in a module name: letters, digits, underscore
|
|
107
108
|
allowed_chars = set(
|
|
108
|
-
"
|
|
109
|
+
"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-"
|
|
109
110
|
)
|
|
110
111
|
|
|
111
112
|
# Make the string lowercase
|
|
112
|
-
sanitized_name =
|
|
113
|
+
sanitized_name = name_with_hyphens.lower()
|
|
113
114
|
|
|
114
115
|
# Remove any characters not allowed in Python module names
|
|
115
116
|
sanitized_name = "".join(c for c in sanitized_name if c in allowed_chars)
|
|
116
117
|
|
|
117
118
|
# Ensure the first character is a letter or underscore
|
|
118
|
-
|
|
119
|
+
while sanitized_name and (
|
|
119
120
|
sanitized_name[0].isdigit() or sanitized_name[0] not in allowed_chars
|
|
120
121
|
):
|
|
121
|
-
sanitized_name =
|
|
122
|
+
sanitized_name = sanitized_name[1:]
|
|
122
123
|
|
|
123
124
|
return sanitized_name
|
flwr/client/__init__.py
CHANGED
|
@@ -15,12 +15,12 @@
|
|
|
15
15
|
"""Flower client."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from .app import run_client_app as run_client_app
|
|
19
18
|
from .app import start_client as start_client
|
|
20
19
|
from .app import start_numpy_client as start_numpy_client
|
|
21
20
|
from .client import Client as Client
|
|
22
21
|
from .client_app import ClientApp as ClientApp
|
|
23
22
|
from .numpy_client import NumPyClient as NumPyClient
|
|
23
|
+
from .supernode import run_client_app as run_client_app
|
|
24
24
|
from .supernode import run_supernode as run_supernode
|
|
25
25
|
from .typing import ClientFn as ClientFn
|
|
26
26
|
|
flwr/client/app.py
CHANGED
|
@@ -14,13 +14,12 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Flower client app."""
|
|
16
16
|
|
|
17
|
-
import argparse
|
|
18
17
|
import sys
|
|
19
18
|
import time
|
|
20
19
|
from logging import DEBUG, ERROR, INFO, WARN
|
|
21
|
-
from pathlib import Path
|
|
22
20
|
from typing import Callable, ContextManager, Optional, Tuple, Type, Union
|
|
23
21
|
|
|
22
|
+
from cryptography.hazmat.primitives.asymmetric import ec
|
|
24
23
|
from grpc import RpcError
|
|
25
24
|
|
|
26
25
|
from flwr.client.client import Client
|
|
@@ -36,10 +35,8 @@ from flwr.common.constant import (
|
|
|
36
35
|
TRANSPORT_TYPES,
|
|
37
36
|
ErrorCode,
|
|
38
37
|
)
|
|
39
|
-
from flwr.common.exit_handlers import register_exit_handlers
|
|
40
38
|
from flwr.common.logger import log, warn_deprecated_feature
|
|
41
39
|
from flwr.common.message import Error
|
|
42
|
-
from flwr.common.object_ref import load_app, validate
|
|
43
40
|
from flwr.common.retry_invoker import RetryInvoker, exponential
|
|
44
41
|
|
|
45
42
|
from .grpc_client.connection import grpc_connection
|
|
@@ -47,94 +44,6 @@ from .grpc_rere_client.connection import grpc_request_response
|
|
|
47
44
|
from .message_handler.message_handler import handle_control_message
|
|
48
45
|
from .node_state import NodeState
|
|
49
46
|
from .numpy_client import NumPyClient
|
|
50
|
-
from .supernode.app import parse_args_run_client_app
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
def run_client_app() -> None:
|
|
54
|
-
"""Run Flower client app."""
|
|
55
|
-
log(INFO, "Long-running Flower client starting")
|
|
56
|
-
|
|
57
|
-
event(EventType.RUN_CLIENT_APP_ENTER)
|
|
58
|
-
|
|
59
|
-
args = _parse_args_run_client_app().parse_args()
|
|
60
|
-
|
|
61
|
-
# Obtain certificates
|
|
62
|
-
if args.insecure:
|
|
63
|
-
if args.root_certificates is not None:
|
|
64
|
-
sys.exit(
|
|
65
|
-
"Conflicting options: The '--insecure' flag disables HTTPS, "
|
|
66
|
-
"but '--root-certificates' was also specified. Please remove "
|
|
67
|
-
"the '--root-certificates' option when running in insecure mode, "
|
|
68
|
-
"or omit '--insecure' to use HTTPS."
|
|
69
|
-
)
|
|
70
|
-
log(
|
|
71
|
-
WARN,
|
|
72
|
-
"Option `--insecure` was set. "
|
|
73
|
-
"Starting insecure HTTP client connected to %s.",
|
|
74
|
-
args.server,
|
|
75
|
-
)
|
|
76
|
-
root_certificates = None
|
|
77
|
-
else:
|
|
78
|
-
# Load the certificates if provided, or load the system certificates
|
|
79
|
-
cert_path = args.root_certificates
|
|
80
|
-
if cert_path is None:
|
|
81
|
-
root_certificates = None
|
|
82
|
-
else:
|
|
83
|
-
root_certificates = Path(cert_path).read_bytes()
|
|
84
|
-
log(
|
|
85
|
-
DEBUG,
|
|
86
|
-
"Starting secure HTTPS client connected to %s "
|
|
87
|
-
"with the following certificates: %s.",
|
|
88
|
-
args.server,
|
|
89
|
-
cert_path,
|
|
90
|
-
)
|
|
91
|
-
|
|
92
|
-
log(
|
|
93
|
-
DEBUG,
|
|
94
|
-
"Flower will load ClientApp `%s`",
|
|
95
|
-
getattr(args, "client-app"),
|
|
96
|
-
)
|
|
97
|
-
|
|
98
|
-
client_app_dir = args.dir
|
|
99
|
-
if client_app_dir is not None:
|
|
100
|
-
sys.path.insert(0, client_app_dir)
|
|
101
|
-
|
|
102
|
-
app_ref: str = getattr(args, "client-app")
|
|
103
|
-
valid, error_msg = validate(app_ref)
|
|
104
|
-
if not valid and error_msg:
|
|
105
|
-
raise LoadClientAppError(error_msg) from None
|
|
106
|
-
|
|
107
|
-
def _load() -> ClientApp:
|
|
108
|
-
client_app = load_app(app_ref, LoadClientAppError)
|
|
109
|
-
|
|
110
|
-
if not isinstance(client_app, ClientApp):
|
|
111
|
-
raise LoadClientAppError(
|
|
112
|
-
f"Attribute {app_ref} is not of type {ClientApp}",
|
|
113
|
-
) from None
|
|
114
|
-
|
|
115
|
-
return client_app
|
|
116
|
-
|
|
117
|
-
_start_client_internal(
|
|
118
|
-
server_address=args.server,
|
|
119
|
-
load_client_app_fn=_load,
|
|
120
|
-
transport="rest" if args.rest else "grpc-rere",
|
|
121
|
-
root_certificates=root_certificates,
|
|
122
|
-
insecure=args.insecure,
|
|
123
|
-
max_retries=args.max_retries,
|
|
124
|
-
max_wait_time=args.max_wait_time,
|
|
125
|
-
)
|
|
126
|
-
register_exit_handlers(event_type=EventType.RUN_CLIENT_APP_LEAVE)
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
def _parse_args_run_client_app() -> argparse.ArgumentParser:
|
|
130
|
-
"""Parse flower-client-app command line arguments."""
|
|
131
|
-
parser = argparse.ArgumentParser(
|
|
132
|
-
description="Start a Flower client app",
|
|
133
|
-
)
|
|
134
|
-
|
|
135
|
-
parse_args_run_client_app(parser=parser)
|
|
136
|
-
|
|
137
|
-
return parser
|
|
138
47
|
|
|
139
48
|
|
|
140
49
|
def _check_actionable_client(
|
|
@@ -165,6 +74,9 @@ def start_client(
|
|
|
165
74
|
root_certificates: Optional[Union[bytes, str]] = None,
|
|
166
75
|
insecure: Optional[bool] = None,
|
|
167
76
|
transport: Optional[str] = None,
|
|
77
|
+
authentication_keys: Optional[
|
|
78
|
+
Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
|
|
79
|
+
] = None,
|
|
168
80
|
max_retries: Optional[int] = None,
|
|
169
81
|
max_wait_time: Optional[float] = None,
|
|
170
82
|
) -> None:
|
|
@@ -249,6 +161,7 @@ def start_client(
|
|
|
249
161
|
root_certificates=root_certificates,
|
|
250
162
|
insecure=insecure,
|
|
251
163
|
transport=transport,
|
|
164
|
+
authentication_keys=authentication_keys,
|
|
252
165
|
max_retries=max_retries,
|
|
253
166
|
max_wait_time=max_wait_time,
|
|
254
167
|
)
|
|
@@ -269,6 +182,9 @@ def _start_client_internal(
|
|
|
269
182
|
root_certificates: Optional[Union[bytes, str]] = None,
|
|
270
183
|
insecure: Optional[bool] = None,
|
|
271
184
|
transport: Optional[str] = None,
|
|
185
|
+
authentication_keys: Optional[
|
|
186
|
+
Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
|
|
187
|
+
] = None,
|
|
272
188
|
max_retries: Optional[int] = None,
|
|
273
189
|
max_wait_time: Optional[float] = None,
|
|
274
190
|
) -> None:
|
|
@@ -393,6 +309,7 @@ def _start_client_internal(
|
|
|
393
309
|
retry_invoker,
|
|
394
310
|
grpc_max_message_length,
|
|
395
311
|
root_certificates,
|
|
312
|
+
authentication_keys,
|
|
396
313
|
) as conn:
|
|
397
314
|
# pylint: disable-next=W0612
|
|
398
315
|
receive, send, create_node, delete_node, get_run = conn
|
|
@@ -606,7 +523,14 @@ def start_numpy_client(
|
|
|
606
523
|
|
|
607
524
|
def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
|
|
608
525
|
Callable[
|
|
609
|
-
[
|
|
526
|
+
[
|
|
527
|
+
str,
|
|
528
|
+
bool,
|
|
529
|
+
RetryInvoker,
|
|
530
|
+
int,
|
|
531
|
+
Union[bytes, str, None],
|
|
532
|
+
Optional[Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]],
|
|
533
|
+
],
|
|
610
534
|
ContextManager[
|
|
611
535
|
Tuple[
|
|
612
536
|
Callable[[], Optional[Message]],
|
|
@@ -22,6 +22,8 @@ from pathlib import Path
|
|
|
22
22
|
from queue import Queue
|
|
23
23
|
from typing import Callable, Iterator, Optional, Tuple, Union, cast
|
|
24
24
|
|
|
25
|
+
from cryptography.hazmat.primitives.asymmetric import ec
|
|
26
|
+
|
|
25
27
|
from flwr.common import (
|
|
26
28
|
DEFAULT_TTL,
|
|
27
29
|
GRPC_MAX_MESSAGE_LENGTH,
|
|
@@ -56,12 +58,15 @@ def on_channel_state_change(channel_connectivity: str) -> None:
|
|
|
56
58
|
|
|
57
59
|
|
|
58
60
|
@contextmanager
|
|
59
|
-
def grpc_connection( # pylint: disable=R0915
|
|
61
|
+
def grpc_connection( # pylint: disable=R0913, R0915
|
|
60
62
|
server_address: str,
|
|
61
63
|
insecure: bool,
|
|
62
64
|
retry_invoker: RetryInvoker, # pylint: disable=unused-argument
|
|
63
65
|
max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
|
|
64
66
|
root_certificates: Optional[Union[bytes, str]] = None,
|
|
67
|
+
authentication_keys: Optional[ # pylint: disable=unused-argument
|
|
68
|
+
Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
|
|
69
|
+
] = None,
|
|
65
70
|
) -> Iterator[
|
|
66
71
|
Tuple[
|
|
67
72
|
Callable[[], Optional[Message]],
|