flwr-nightly 1.9.0.dev20240420__py3-none-any.whl → 1.9.0.dev20240423__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/new/new.py +1 -0
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +94 -0
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +17 -0
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +24 -0
- flwr/client/supernode/app.py +16 -4
- flwr/server/compat/app.py +2 -2
- flwr/server/compat/driver_client_proxy.py +1 -1
- flwr/server/driver/__init__.py +3 -0
- flwr/server/driver/driver.py +12 -242
- flwr/server/driver/grpc_driver.py +306 -0
- flwr/server/run_serverapp.py +4 -4
- flwr/simulation/run_simulation.py +2 -2
- {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240423.dist-info}/METADATA +1 -1
- {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240423.dist-info}/RECORD +17 -14
- flwr/server/driver/abc_driver.py +0 -140
- {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240423.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240423.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240423.dist-info}/entry_points.txt +0 -0
flwr/cli/new/new.py
CHANGED
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
"""$project_name: A Flower / Scikit-Learn app."""
|
|
2
|
+
|
|
3
|
+
import warnings
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from flwr.client import NumPyClient, ClientApp
|
|
7
|
+
from flwr_datasets import FederatedDataset
|
|
8
|
+
from sklearn.linear_model import LogisticRegression
|
|
9
|
+
from sklearn.metrics import log_loss
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def get_model_parameters(model):
|
|
13
|
+
if model.fit_intercept:
|
|
14
|
+
params = [
|
|
15
|
+
model.coef_,
|
|
16
|
+
model.intercept_,
|
|
17
|
+
]
|
|
18
|
+
else:
|
|
19
|
+
params = [model.coef_]
|
|
20
|
+
return params
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def set_model_params(model, params):
|
|
24
|
+
model.coef_ = params[0]
|
|
25
|
+
if model.fit_intercept:
|
|
26
|
+
model.intercept_ = params[1]
|
|
27
|
+
return model
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def set_initial_params(model):
|
|
31
|
+
n_classes = 10 # MNIST has 10 classes
|
|
32
|
+
n_features = 784 # Number of features in dataset
|
|
33
|
+
model.classes_ = np.array([i for i in range(10)])
|
|
34
|
+
|
|
35
|
+
model.coef_ = np.zeros((n_classes, n_features))
|
|
36
|
+
if model.fit_intercept:
|
|
37
|
+
model.intercept_ = np.zeros((n_classes,))
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class FlowerClient(NumPyClient):
|
|
41
|
+
def __init__(self, model, X_train, X_test, y_train, y_test):
|
|
42
|
+
self.model = model
|
|
43
|
+
self.X_train = X_train
|
|
44
|
+
self.X_test = X_test
|
|
45
|
+
self.y_train = y_train
|
|
46
|
+
self.y_test = y_test
|
|
47
|
+
|
|
48
|
+
def get_parameters(self, config):
|
|
49
|
+
return get_model_parameters(self.model)
|
|
50
|
+
|
|
51
|
+
def fit(self, parameters, config):
|
|
52
|
+
set_model_params(self.model, parameters)
|
|
53
|
+
|
|
54
|
+
# Ignore convergence failure due to low local epochs
|
|
55
|
+
with warnings.catch_warnings():
|
|
56
|
+
warnings.simplefilter("ignore")
|
|
57
|
+
self.model.fit(self.X_train, self.y_train)
|
|
58
|
+
|
|
59
|
+
return get_model_parameters(self.model), len(self.X_train), {}
|
|
60
|
+
|
|
61
|
+
def evaluate(self, parameters, config):
|
|
62
|
+
set_model_params(self.model, parameters)
|
|
63
|
+
|
|
64
|
+
loss = log_loss(self.y_test, self.model.predict_proba(self.X_test))
|
|
65
|
+
accuracy = self.model.score(self.X_test, self.y_test)
|
|
66
|
+
|
|
67
|
+
return loss, len(self.X_test), {"accuracy": accuracy}
|
|
68
|
+
|
|
69
|
+
fds = FederatedDataset(dataset="mnist", partitioners={"train": 2})
|
|
70
|
+
|
|
71
|
+
def client_fn(cid: str):
|
|
72
|
+
dataset = fds.load_partition(int(cid), "train").with_format("numpy")
|
|
73
|
+
|
|
74
|
+
X, y = dataset["image"].reshape((len(dataset), -1)), dataset["label"]
|
|
75
|
+
|
|
76
|
+
# Split the on edge data: 80% train, 20% test
|
|
77
|
+
X_train, X_test = X[: int(0.8 * len(X))], X[int(0.8 * len(X)) :]
|
|
78
|
+
y_train, y_test = y[: int(0.8 * len(y))], y[int(0.8 * len(y)) :]
|
|
79
|
+
|
|
80
|
+
# Create LogisticRegression Model
|
|
81
|
+
model = LogisticRegression(
|
|
82
|
+
penalty="l2",
|
|
83
|
+
max_iter=1, # local epoch
|
|
84
|
+
warm_start=True, # prevent refreshing weights when fitting
|
|
85
|
+
)
|
|
86
|
+
|
|
87
|
+
# Setting initial parameters, akin to model.compile for keras models
|
|
88
|
+
set_initial_params(model)
|
|
89
|
+
|
|
90
|
+
return FlowerClient(model, X_train, X_test, y_train, y_test).to_client()
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
# Flower ClientApp
|
|
94
|
+
app = ClientApp(client_fn=client_fn)
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
"""$project_name: A Flower / Scikit-Learn app."""
|
|
2
|
+
|
|
3
|
+
from flwr.server import ServerApp, ServerConfig
|
|
4
|
+
from flwr.server.strategy import FedAvg
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
strategy = FedAvg(
|
|
8
|
+
fraction_fit=1.0,
|
|
9
|
+
fraction_evaluate=1.0,
|
|
10
|
+
min_available_clients=2,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
# Create ServerApp
|
|
14
|
+
app = ServerApp(
|
|
15
|
+
config=ServerConfig(num_rounds=3),
|
|
16
|
+
strategy=strategy,
|
|
17
|
+
)
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["hatchling"]
|
|
3
|
+
build-backend = "hatchling.build"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "$project_name"
|
|
7
|
+
version = "1.0.0"
|
|
8
|
+
description = ""
|
|
9
|
+
authors = [
|
|
10
|
+
{ name = "The Flower Authors", email = "hello@flower.ai" },
|
|
11
|
+
]
|
|
12
|
+
license = {text = "Apache License (2.0)"}
|
|
13
|
+
dependencies = [
|
|
14
|
+
"flwr[simulation]>=1.8.0,<2.0",
|
|
15
|
+
"flwr-datasets[vision]>=0.0.2,<1.0.0",
|
|
16
|
+
"scikit-learn>=1.1.1",
|
|
17
|
+
]
|
|
18
|
+
|
|
19
|
+
[tool.hatch.build.targets.wheel]
|
|
20
|
+
packages = ["."]
|
|
21
|
+
|
|
22
|
+
[flower.components]
|
|
23
|
+
serverapp = "$project_name.server:app"
|
|
24
|
+
clientapp = "$project_name.client:app"
|
flwr/client/supernode/app.py
CHANGED
|
@@ -28,12 +28,11 @@ def run_supernode() -> None:
|
|
|
28
28
|
|
|
29
29
|
event(EventType.RUN_SUPERNODE_ENTER)
|
|
30
30
|
|
|
31
|
-
|
|
31
|
+
_ = _parse_args_run_supernode().parse_args()
|
|
32
32
|
|
|
33
33
|
log(
|
|
34
34
|
DEBUG,
|
|
35
|
-
"Flower
|
|
36
|
-
getattr(args, "client-app"),
|
|
35
|
+
"Flower SuperNode starting...",
|
|
37
36
|
)
|
|
38
37
|
|
|
39
38
|
# Graceful shutdown
|
|
@@ -48,7 +47,16 @@ def _parse_args_run_supernode() -> argparse.ArgumentParser:
|
|
|
48
47
|
description="Start a Flower SuperNode",
|
|
49
48
|
)
|
|
50
49
|
|
|
51
|
-
|
|
50
|
+
parser.add_argument(
|
|
51
|
+
"client-app",
|
|
52
|
+
nargs="?",
|
|
53
|
+
default="",
|
|
54
|
+
help="For example: `client:app` or `project.package.module:wrapper.app`. "
|
|
55
|
+
"This is optional and serves as the default ClientApp to be loaded when "
|
|
56
|
+
"the ServerApp does not specify `fab_id` and `fab_version`. "
|
|
57
|
+
"If not provided, defaults to an empty string.",
|
|
58
|
+
)
|
|
59
|
+
_parse_args_common(parser)
|
|
52
60
|
|
|
53
61
|
return parser
|
|
54
62
|
|
|
@@ -59,6 +67,10 @@ def parse_args_run_client_app(parser: argparse.ArgumentParser) -> None:
|
|
|
59
67
|
"client-app",
|
|
60
68
|
help="For example: `client:app` or `project.package.module:wrapper.app`",
|
|
61
69
|
)
|
|
70
|
+
_parse_args_common(parser)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _parse_args_common(parser: argparse.ArgumentParser) -> None:
|
|
62
74
|
parser.add_argument(
|
|
63
75
|
"--insecure",
|
|
64
76
|
action="store_true",
|
flwr/server/compat/app.py
CHANGED
|
@@ -29,7 +29,7 @@ from flwr.server.server import Server, init_defaults, run_fl
|
|
|
29
29
|
from flwr.server.server_config import ServerConfig
|
|
30
30
|
from flwr.server.strategy import Strategy
|
|
31
31
|
|
|
32
|
-
from ..driver import Driver
|
|
32
|
+
from ..driver import Driver, GrpcDriver
|
|
33
33
|
from .app_utils import start_update_client_manager_thread
|
|
34
34
|
|
|
35
35
|
DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"
|
|
@@ -114,7 +114,7 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
|
|
|
114
114
|
# Create the Driver
|
|
115
115
|
if isinstance(root_certificates, str):
|
|
116
116
|
root_certificates = Path(root_certificates).read_bytes()
|
|
117
|
-
driver =
|
|
117
|
+
driver = GrpcDriver(
|
|
118
118
|
driver_service_address=address, root_certificates=root_certificates
|
|
119
119
|
)
|
|
120
120
|
|
|
@@ -25,7 +25,7 @@ from flwr.common import serde
|
|
|
25
25
|
from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611
|
|
26
26
|
from flwr.server.client_proxy import ClientProxy
|
|
27
27
|
|
|
28
|
-
from ..driver.
|
|
28
|
+
from ..driver.grpc_driver import GrpcDriverHelper
|
|
29
29
|
|
|
30
30
|
SLEEP_TIME = 1
|
|
31
31
|
|
flwr/server/driver/__init__.py
CHANGED
flwr/server/driver/driver.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -12,180 +12,19 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
"""
|
|
15
|
+
"""Driver (abstract base class)."""
|
|
16
16
|
|
|
17
|
-
import time
|
|
18
|
-
import warnings
|
|
19
|
-
from logging import DEBUG, ERROR, WARNING
|
|
20
|
-
from typing import Iterable, List, Optional, Tuple
|
|
21
17
|
|
|
22
|
-
import
|
|
18
|
+
from abc import ABC, abstractmethod
|
|
19
|
+
from typing import Iterable, List, Optional
|
|
23
20
|
|
|
24
|
-
from flwr.common import
|
|
25
|
-
from flwr.common.grpc import create_channel
|
|
26
|
-
from flwr.common.logger import log
|
|
27
|
-
from flwr.common.serde import message_from_taskres, message_to_taskins
|
|
28
|
-
from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
|
29
|
-
CreateRunRequest,
|
|
30
|
-
CreateRunResponse,
|
|
31
|
-
GetNodesRequest,
|
|
32
|
-
GetNodesResponse,
|
|
33
|
-
PullTaskResRequest,
|
|
34
|
-
PullTaskResResponse,
|
|
35
|
-
PushTaskInsRequest,
|
|
36
|
-
PushTaskInsResponse,
|
|
37
|
-
)
|
|
38
|
-
from flwr.proto.driver_pb2_grpc import DriverStub # pylint: disable=E0611
|
|
39
|
-
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
40
|
-
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
|
|
21
|
+
from flwr.common import Message, RecordSet
|
|
41
22
|
|
|
42
|
-
DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"
|
|
43
23
|
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
Call `connect()` on the `GrpcDriverHelper` instance before calling any of the other
|
|
48
|
-
`GrpcDriverHelper` methods.
|
|
49
|
-
"""
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
class GrpcDriverHelper:
|
|
53
|
-
"""`GrpcDriverHelper` provides access to the gRPC Driver API/service."""
|
|
54
|
-
|
|
55
|
-
def __init__(
|
|
56
|
-
self,
|
|
57
|
-
driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER,
|
|
58
|
-
root_certificates: Optional[bytes] = None,
|
|
59
|
-
) -> None:
|
|
60
|
-
self.driver_service_address = driver_service_address
|
|
61
|
-
self.root_certificates = root_certificates
|
|
62
|
-
self.channel: Optional[grpc.Channel] = None
|
|
63
|
-
self.stub: Optional[DriverStub] = None
|
|
64
|
-
|
|
65
|
-
def connect(self) -> None:
|
|
66
|
-
"""Connect to the Driver API."""
|
|
67
|
-
event(EventType.DRIVER_CONNECT)
|
|
68
|
-
if self.channel is not None or self.stub is not None:
|
|
69
|
-
log(WARNING, "Already connected")
|
|
70
|
-
return
|
|
71
|
-
self.channel = create_channel(
|
|
72
|
-
server_address=self.driver_service_address,
|
|
73
|
-
insecure=(self.root_certificates is None),
|
|
74
|
-
root_certificates=self.root_certificates,
|
|
75
|
-
)
|
|
76
|
-
self.stub = DriverStub(self.channel)
|
|
77
|
-
log(DEBUG, "[Driver] Connected to %s", self.driver_service_address)
|
|
78
|
-
|
|
79
|
-
def disconnect(self) -> None:
|
|
80
|
-
"""Disconnect from the Driver API."""
|
|
81
|
-
event(EventType.DRIVER_DISCONNECT)
|
|
82
|
-
if self.channel is None or self.stub is None:
|
|
83
|
-
log(DEBUG, "Already disconnected")
|
|
84
|
-
return
|
|
85
|
-
channel = self.channel
|
|
86
|
-
self.channel = None
|
|
87
|
-
self.stub = None
|
|
88
|
-
channel.close()
|
|
89
|
-
log(DEBUG, "[Driver] Disconnected")
|
|
90
|
-
|
|
91
|
-
def create_run(self, req: CreateRunRequest) -> CreateRunResponse:
|
|
92
|
-
"""Request for run ID."""
|
|
93
|
-
# Check if channel is open
|
|
94
|
-
if self.stub is None:
|
|
95
|
-
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
96
|
-
raise ConnectionError("`GrpcDriverHelper` instance not connected")
|
|
97
|
-
|
|
98
|
-
# Call Driver API
|
|
99
|
-
res: CreateRunResponse = self.stub.CreateRun(request=req)
|
|
100
|
-
return res
|
|
101
|
-
|
|
102
|
-
def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse:
|
|
103
|
-
"""Get client IDs."""
|
|
104
|
-
# Check if channel is open
|
|
105
|
-
if self.stub is None:
|
|
106
|
-
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
107
|
-
raise ConnectionError("`GrpcDriverHelper` instance not connected")
|
|
108
|
-
|
|
109
|
-
# Call gRPC Driver API
|
|
110
|
-
res: GetNodesResponse = self.stub.GetNodes(request=req)
|
|
111
|
-
return res
|
|
112
|
-
|
|
113
|
-
def push_task_ins(self, req: PushTaskInsRequest) -> PushTaskInsResponse:
|
|
114
|
-
"""Schedule tasks."""
|
|
115
|
-
# Check if channel is open
|
|
116
|
-
if self.stub is None:
|
|
117
|
-
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
118
|
-
raise ConnectionError("`GrpcDriverHelper` instance not connected")
|
|
119
|
-
|
|
120
|
-
# Call gRPC Driver API
|
|
121
|
-
res: PushTaskInsResponse = self.stub.PushTaskIns(request=req)
|
|
122
|
-
return res
|
|
123
|
-
|
|
124
|
-
def pull_task_res(self, req: PullTaskResRequest) -> PullTaskResResponse:
|
|
125
|
-
"""Get task results."""
|
|
126
|
-
# Check if channel is open
|
|
127
|
-
if self.stub is None:
|
|
128
|
-
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
129
|
-
raise ConnectionError("`GrpcDriverHelper` instance not connected")
|
|
130
|
-
|
|
131
|
-
# Call Driver API
|
|
132
|
-
res: PullTaskResResponse = self.stub.PullTaskRes(request=req)
|
|
133
|
-
return res
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
class Driver:
|
|
137
|
-
"""`Driver` class provides an interface to the Driver API.
|
|
138
|
-
|
|
139
|
-
Parameters
|
|
140
|
-
----------
|
|
141
|
-
driver_service_address : Optional[str]
|
|
142
|
-
The IPv4 or IPv6 address of the Driver API server.
|
|
143
|
-
Defaults to `"[::]:9091"`.
|
|
144
|
-
certificates : bytes (default: None)
|
|
145
|
-
Tuple containing root certificate, server certificate, and private key
|
|
146
|
-
to start a secure SSL-enabled server. The tuple is expected to have
|
|
147
|
-
three bytes elements in the following order:
|
|
148
|
-
|
|
149
|
-
* CA certificate.
|
|
150
|
-
* server certificate.
|
|
151
|
-
* server private key.
|
|
152
|
-
"""
|
|
153
|
-
|
|
154
|
-
def __init__(
|
|
155
|
-
self,
|
|
156
|
-
driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER,
|
|
157
|
-
root_certificates: Optional[bytes] = None,
|
|
158
|
-
) -> None:
|
|
159
|
-
self.addr = driver_service_address
|
|
160
|
-
self.root_certificates = root_certificates
|
|
161
|
-
self.grpc_driver_helper: Optional[GrpcDriverHelper] = None
|
|
162
|
-
self.run_id: Optional[int] = None
|
|
163
|
-
self.node = Node(node_id=0, anonymous=True)
|
|
164
|
-
|
|
165
|
-
def _get_grpc_driver_helper_and_run_id(self) -> Tuple[GrpcDriverHelper, int]:
|
|
166
|
-
# Check if the GrpcDriverHelper is initialized
|
|
167
|
-
if self.grpc_driver_helper is None or self.run_id is None:
|
|
168
|
-
# Connect and create run
|
|
169
|
-
self.grpc_driver_helper = GrpcDriverHelper(
|
|
170
|
-
driver_service_address=self.addr,
|
|
171
|
-
root_certificates=self.root_certificates,
|
|
172
|
-
)
|
|
173
|
-
self.grpc_driver_helper.connect()
|
|
174
|
-
res = self.grpc_driver_helper.create_run(CreateRunRequest())
|
|
175
|
-
self.run_id = res.run_id
|
|
176
|
-
return self.grpc_driver_helper, self.run_id
|
|
177
|
-
|
|
178
|
-
def _check_message(self, message: Message) -> None:
|
|
179
|
-
# Check if the message is valid
|
|
180
|
-
if not (
|
|
181
|
-
message.metadata.run_id == self.run_id
|
|
182
|
-
and message.metadata.src_node_id == self.node.node_id
|
|
183
|
-
and message.metadata.message_id == ""
|
|
184
|
-
and message.metadata.reply_to_message == ""
|
|
185
|
-
and message.metadata.ttl > 0
|
|
186
|
-
):
|
|
187
|
-
raise ValueError(f"Invalid message: {message}")
|
|
24
|
+
class Driver(ABC):
|
|
25
|
+
"""Abstract base Driver class for the Driver API."""
|
|
188
26
|
|
|
27
|
+
@abstractmethod
|
|
189
28
|
def create_message( # pylint: disable=too-many-arguments
|
|
190
29
|
self,
|
|
191
30
|
content: RecordSet,
|
|
@@ -223,35 +62,12 @@ class Driver:
|
|
|
223
62
|
message : Message
|
|
224
63
|
A new `Message` instance with the specified content and metadata.
|
|
225
64
|
"""
|
|
226
|
-
_, run_id = self._get_grpc_driver_helper_and_run_id()
|
|
227
|
-
if ttl:
|
|
228
|
-
warnings.warn(
|
|
229
|
-
"A custom TTL was set, but note that the SuperLink does not enforce "
|
|
230
|
-
"the TTL yet. The SuperLink will start enforcing the TTL in a future "
|
|
231
|
-
"version of Flower.",
|
|
232
|
-
stacklevel=2,
|
|
233
|
-
)
|
|
234
|
-
|
|
235
|
-
ttl_ = DEFAULT_TTL if ttl is None else ttl
|
|
236
|
-
metadata = Metadata(
|
|
237
|
-
run_id=run_id,
|
|
238
|
-
message_id="", # Will be set by the server
|
|
239
|
-
src_node_id=self.node.node_id,
|
|
240
|
-
dst_node_id=dst_node_id,
|
|
241
|
-
reply_to_message="",
|
|
242
|
-
group_id=group_id,
|
|
243
|
-
ttl=ttl_,
|
|
244
|
-
message_type=message_type,
|
|
245
|
-
)
|
|
246
|
-
return Message(metadata=metadata, content=content)
|
|
247
65
|
|
|
66
|
+
@abstractmethod
|
|
248
67
|
def get_node_ids(self) -> List[int]:
|
|
249
68
|
"""Get node IDs."""
|
|
250
|
-
grpc_driver_helper, run_id = self._get_grpc_driver_helper_and_run_id()
|
|
251
|
-
# Call GrpcDriverHelper method
|
|
252
|
-
res = grpc_driver_helper.get_nodes(GetNodesRequest(run_id=run_id))
|
|
253
|
-
return [node.node_id for node in res.nodes]
|
|
254
69
|
|
|
70
|
+
@abstractmethod
|
|
255
71
|
def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
|
|
256
72
|
"""Push messages to specified node IDs.
|
|
257
73
|
|
|
@@ -269,22 +85,8 @@ class Driver:
|
|
|
269
85
|
An iterable of IDs for the messages that were sent, which can be used
|
|
270
86
|
to pull replies.
|
|
271
87
|
"""
|
|
272
|
-
grpc_driver_helper, _ = self._get_grpc_driver_helper_and_run_id()
|
|
273
|
-
# Construct TaskIns
|
|
274
|
-
task_ins_list: List[TaskIns] = []
|
|
275
|
-
for msg in messages:
|
|
276
|
-
# Check message
|
|
277
|
-
self._check_message(msg)
|
|
278
|
-
# Convert Message to TaskIns
|
|
279
|
-
taskins = message_to_taskins(msg)
|
|
280
|
-
# Add to list
|
|
281
|
-
task_ins_list.append(taskins)
|
|
282
|
-
# Call GrpcDriverHelper method
|
|
283
|
-
res = grpc_driver_helper.push_task_ins(
|
|
284
|
-
PushTaskInsRequest(task_ins_list=task_ins_list)
|
|
285
|
-
)
|
|
286
|
-
return list(res.task_ids)
|
|
287
88
|
|
|
89
|
+
@abstractmethod
|
|
288
90
|
def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
|
|
289
91
|
"""Pull messages based on message IDs.
|
|
290
92
|
|
|
@@ -301,15 +103,8 @@ class Driver:
|
|
|
301
103
|
messages : Iterable[Message]
|
|
302
104
|
An iterable of messages received.
|
|
303
105
|
"""
|
|
304
|
-
grpc_driver, _ = self._get_grpc_driver_helper_and_run_id()
|
|
305
|
-
# Pull TaskRes
|
|
306
|
-
res = grpc_driver.pull_task_res(
|
|
307
|
-
PullTaskResRequest(node=self.node, task_ids=message_ids)
|
|
308
|
-
)
|
|
309
|
-
# Convert TaskRes to Message
|
|
310
|
-
msgs = [message_from_taskres(taskres) for taskres in res.task_res_list]
|
|
311
|
-
return msgs
|
|
312
106
|
|
|
107
|
+
@abstractmethod
|
|
313
108
|
def send_and_receive(
|
|
314
109
|
self,
|
|
315
110
|
messages: Iterable[Message],
|
|
@@ -343,28 +138,3 @@ class Driver:
|
|
|
343
138
|
replies for all sent messages. A message remains valid until its TTL,
|
|
344
139
|
which is not affected by `timeout`.
|
|
345
140
|
"""
|
|
346
|
-
# Push messages
|
|
347
|
-
msg_ids = set(self.push_messages(messages))
|
|
348
|
-
|
|
349
|
-
# Pull messages
|
|
350
|
-
end_time = time.time() + (timeout if timeout is not None else 0.0)
|
|
351
|
-
ret: List[Message] = []
|
|
352
|
-
while timeout is None or time.time() < end_time:
|
|
353
|
-
res_msgs = self.pull_messages(msg_ids)
|
|
354
|
-
ret.extend(res_msgs)
|
|
355
|
-
msg_ids.difference_update(
|
|
356
|
-
{msg.metadata.reply_to_message for msg in res_msgs}
|
|
357
|
-
)
|
|
358
|
-
if len(msg_ids) == 0:
|
|
359
|
-
break
|
|
360
|
-
# Sleep
|
|
361
|
-
time.sleep(3)
|
|
362
|
-
return ret
|
|
363
|
-
|
|
364
|
-
def close(self) -> None:
|
|
365
|
-
"""Disconnect from the SuperLink if connected."""
|
|
366
|
-
# Check if GrpcDriverHelper is initialized
|
|
367
|
-
if self.grpc_driver_helper is None:
|
|
368
|
-
return
|
|
369
|
-
# Disconnect
|
|
370
|
-
self.grpc_driver_helper.disconnect()
|
|
@@ -0,0 +1,306 @@
|
|
|
1
|
+
# Copyright 2022 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Flower gRPC Driver."""
|
|
16
|
+
|
|
17
|
+
import time
|
|
18
|
+
import warnings
|
|
19
|
+
from logging import DEBUG, ERROR, WARNING
|
|
20
|
+
from typing import Iterable, List, Optional, Tuple
|
|
21
|
+
|
|
22
|
+
import grpc
|
|
23
|
+
|
|
24
|
+
from flwr.common import DEFAULT_TTL, EventType, Message, Metadata, RecordSet, event
|
|
25
|
+
from flwr.common.grpc import create_channel
|
|
26
|
+
from flwr.common.logger import log
|
|
27
|
+
from flwr.common.serde import message_from_taskres, message_to_taskins
|
|
28
|
+
from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
|
29
|
+
CreateRunRequest,
|
|
30
|
+
CreateRunResponse,
|
|
31
|
+
GetNodesRequest,
|
|
32
|
+
GetNodesResponse,
|
|
33
|
+
PullTaskResRequest,
|
|
34
|
+
PullTaskResResponse,
|
|
35
|
+
PushTaskInsRequest,
|
|
36
|
+
PushTaskInsResponse,
|
|
37
|
+
)
|
|
38
|
+
from flwr.proto.driver_pb2_grpc import DriverStub # pylint: disable=E0611
|
|
39
|
+
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
40
|
+
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
|
|
41
|
+
|
|
42
|
+
from .driver import Driver
|
|
43
|
+
|
|
44
|
+
DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"
|
|
45
|
+
|
|
46
|
+
ERROR_MESSAGE_DRIVER_NOT_CONNECTED = """
|
|
47
|
+
[Driver] Error: Not connected.
|
|
48
|
+
|
|
49
|
+
Call `connect()` on the `GrpcDriverHelper` instance before calling any of the other
|
|
50
|
+
`GrpcDriverHelper` methods.
|
|
51
|
+
"""
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class GrpcDriverHelper:
|
|
55
|
+
"""`GrpcDriverHelper` provides access to the gRPC Driver API/service."""
|
|
56
|
+
|
|
57
|
+
def __init__(
|
|
58
|
+
self,
|
|
59
|
+
driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER,
|
|
60
|
+
root_certificates: Optional[bytes] = None,
|
|
61
|
+
) -> None:
|
|
62
|
+
self.driver_service_address = driver_service_address
|
|
63
|
+
self.root_certificates = root_certificates
|
|
64
|
+
self.channel: Optional[grpc.Channel] = None
|
|
65
|
+
self.stub: Optional[DriverStub] = None
|
|
66
|
+
|
|
67
|
+
def connect(self) -> None:
|
|
68
|
+
"""Connect to the Driver API."""
|
|
69
|
+
event(EventType.DRIVER_CONNECT)
|
|
70
|
+
if self.channel is not None or self.stub is not None:
|
|
71
|
+
log(WARNING, "Already connected")
|
|
72
|
+
return
|
|
73
|
+
self.channel = create_channel(
|
|
74
|
+
server_address=self.driver_service_address,
|
|
75
|
+
insecure=(self.root_certificates is None),
|
|
76
|
+
root_certificates=self.root_certificates,
|
|
77
|
+
)
|
|
78
|
+
self.stub = DriverStub(self.channel)
|
|
79
|
+
log(DEBUG, "[Driver] Connected to %s", self.driver_service_address)
|
|
80
|
+
|
|
81
|
+
def disconnect(self) -> None:
|
|
82
|
+
"""Disconnect from the Driver API."""
|
|
83
|
+
event(EventType.DRIVER_DISCONNECT)
|
|
84
|
+
if self.channel is None or self.stub is None:
|
|
85
|
+
log(DEBUG, "Already disconnected")
|
|
86
|
+
return
|
|
87
|
+
channel = self.channel
|
|
88
|
+
self.channel = None
|
|
89
|
+
self.stub = None
|
|
90
|
+
channel.close()
|
|
91
|
+
log(DEBUG, "[Driver] Disconnected")
|
|
92
|
+
|
|
93
|
+
def create_run(self, req: CreateRunRequest) -> CreateRunResponse:
|
|
94
|
+
"""Request for run ID."""
|
|
95
|
+
# Check if channel is open
|
|
96
|
+
if self.stub is None:
|
|
97
|
+
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
98
|
+
raise ConnectionError("`GrpcDriverHelper` instance not connected")
|
|
99
|
+
|
|
100
|
+
# Call Driver API
|
|
101
|
+
res: CreateRunResponse = self.stub.CreateRun(request=req)
|
|
102
|
+
return res
|
|
103
|
+
|
|
104
|
+
def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse:
|
|
105
|
+
"""Get client IDs."""
|
|
106
|
+
# Check if channel is open
|
|
107
|
+
if self.stub is None:
|
|
108
|
+
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
109
|
+
raise ConnectionError("`GrpcDriverHelper` instance not connected")
|
|
110
|
+
|
|
111
|
+
# Call gRPC Driver API
|
|
112
|
+
res: GetNodesResponse = self.stub.GetNodes(request=req)
|
|
113
|
+
return res
|
|
114
|
+
|
|
115
|
+
def push_task_ins(self, req: PushTaskInsRequest) -> PushTaskInsResponse:
|
|
116
|
+
"""Schedule tasks."""
|
|
117
|
+
# Check if channel is open
|
|
118
|
+
if self.stub is None:
|
|
119
|
+
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
120
|
+
raise ConnectionError("`GrpcDriverHelper` instance not connected")
|
|
121
|
+
|
|
122
|
+
# Call gRPC Driver API
|
|
123
|
+
res: PushTaskInsResponse = self.stub.PushTaskIns(request=req)
|
|
124
|
+
return res
|
|
125
|
+
|
|
126
|
+
def pull_task_res(self, req: PullTaskResRequest) -> PullTaskResResponse:
|
|
127
|
+
"""Get task results."""
|
|
128
|
+
# Check if channel is open
|
|
129
|
+
if self.stub is None:
|
|
130
|
+
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
|
|
131
|
+
raise ConnectionError("`GrpcDriverHelper` instance not connected")
|
|
132
|
+
|
|
133
|
+
# Call Driver API
|
|
134
|
+
res: PullTaskResResponse = self.stub.PullTaskRes(request=req)
|
|
135
|
+
return res
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
class GrpcDriver(Driver):
|
|
139
|
+
"""`Driver` class provides an interface to the Driver API.
|
|
140
|
+
|
|
141
|
+
Parameters
|
|
142
|
+
----------
|
|
143
|
+
driver_service_address : Optional[str]
|
|
144
|
+
The IPv4 or IPv6 address of the Driver API server.
|
|
145
|
+
Defaults to `"[::]:9091"`.
|
|
146
|
+
certificates : bytes (default: None)
|
|
147
|
+
Tuple containing root certificate, server certificate, and private key
|
|
148
|
+
to start a secure SSL-enabled server. The tuple is expected to have
|
|
149
|
+
three bytes elements in the following order:
|
|
150
|
+
|
|
151
|
+
* CA certificate.
|
|
152
|
+
* server certificate.
|
|
153
|
+
* server private key.
|
|
154
|
+
"""
|
|
155
|
+
|
|
156
|
+
def __init__(
|
|
157
|
+
self,
|
|
158
|
+
driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER,
|
|
159
|
+
root_certificates: Optional[bytes] = None,
|
|
160
|
+
) -> None:
|
|
161
|
+
self.addr = driver_service_address
|
|
162
|
+
self.root_certificates = root_certificates
|
|
163
|
+
self.grpc_driver_helper: Optional[GrpcDriverHelper] = None
|
|
164
|
+
self.run_id: Optional[int] = None
|
|
165
|
+
self.node = Node(node_id=0, anonymous=True)
|
|
166
|
+
|
|
167
|
+
def _get_grpc_driver_helper_and_run_id(self) -> Tuple[GrpcDriverHelper, int]:
|
|
168
|
+
# Check if the GrpcDriverHelper is initialized
|
|
169
|
+
if self.grpc_driver_helper is None or self.run_id is None:
|
|
170
|
+
# Connect and create run
|
|
171
|
+
self.grpc_driver_helper = GrpcDriverHelper(
|
|
172
|
+
driver_service_address=self.addr,
|
|
173
|
+
root_certificates=self.root_certificates,
|
|
174
|
+
)
|
|
175
|
+
self.grpc_driver_helper.connect()
|
|
176
|
+
res = self.grpc_driver_helper.create_run(CreateRunRequest())
|
|
177
|
+
self.run_id = res.run_id
|
|
178
|
+
return self.grpc_driver_helper, self.run_id
|
|
179
|
+
|
|
180
|
+
def _check_message(self, message: Message) -> None:
|
|
181
|
+
# Check if the message is valid
|
|
182
|
+
if not (
|
|
183
|
+
message.metadata.run_id == self.run_id
|
|
184
|
+
and message.metadata.src_node_id == self.node.node_id
|
|
185
|
+
and message.metadata.message_id == ""
|
|
186
|
+
and message.metadata.reply_to_message == ""
|
|
187
|
+
and message.metadata.ttl > 0
|
|
188
|
+
):
|
|
189
|
+
raise ValueError(f"Invalid message: {message}")
|
|
190
|
+
|
|
191
|
+
def create_message( # pylint: disable=too-many-arguments
|
|
192
|
+
self,
|
|
193
|
+
content: RecordSet,
|
|
194
|
+
message_type: str,
|
|
195
|
+
dst_node_id: int,
|
|
196
|
+
group_id: str,
|
|
197
|
+
ttl: Optional[float] = None,
|
|
198
|
+
) -> Message:
|
|
199
|
+
"""Create a new message with specified parameters.
|
|
200
|
+
|
|
201
|
+
This method constructs a new `Message` with given content and metadata.
|
|
202
|
+
The `run_id` and `src_node_id` will be set automatically.
|
|
203
|
+
"""
|
|
204
|
+
_, run_id = self._get_grpc_driver_helper_and_run_id()
|
|
205
|
+
if ttl:
|
|
206
|
+
warnings.warn(
|
|
207
|
+
"A custom TTL was set, but note that the SuperLink does not enforce "
|
|
208
|
+
"the TTL yet. The SuperLink will start enforcing the TTL in a future "
|
|
209
|
+
"version of Flower.",
|
|
210
|
+
stacklevel=2,
|
|
211
|
+
)
|
|
212
|
+
|
|
213
|
+
ttl_ = DEFAULT_TTL if ttl is None else ttl
|
|
214
|
+
metadata = Metadata(
|
|
215
|
+
run_id=run_id,
|
|
216
|
+
message_id="", # Will be set by the server
|
|
217
|
+
src_node_id=self.node.node_id,
|
|
218
|
+
dst_node_id=dst_node_id,
|
|
219
|
+
reply_to_message="",
|
|
220
|
+
group_id=group_id,
|
|
221
|
+
ttl=ttl_,
|
|
222
|
+
message_type=message_type,
|
|
223
|
+
)
|
|
224
|
+
return Message(metadata=metadata, content=content)
|
|
225
|
+
|
|
226
|
+
def get_node_ids(self) -> List[int]:
|
|
227
|
+
"""Get node IDs."""
|
|
228
|
+
grpc_driver_helper, run_id = self._get_grpc_driver_helper_and_run_id()
|
|
229
|
+
# Call GrpcDriverHelper method
|
|
230
|
+
res = grpc_driver_helper.get_nodes(GetNodesRequest(run_id=run_id))
|
|
231
|
+
return [node.node_id for node in res.nodes]
|
|
232
|
+
|
|
233
|
+
def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
|
|
234
|
+
"""Push messages to specified node IDs.
|
|
235
|
+
|
|
236
|
+
This method takes an iterable of messages and sends each message
|
|
237
|
+
to the node specified in `dst_node_id`.
|
|
238
|
+
"""
|
|
239
|
+
grpc_driver_helper, _ = self._get_grpc_driver_helper_and_run_id()
|
|
240
|
+
# Construct TaskIns
|
|
241
|
+
task_ins_list: List[TaskIns] = []
|
|
242
|
+
for msg in messages:
|
|
243
|
+
# Check message
|
|
244
|
+
self._check_message(msg)
|
|
245
|
+
# Convert Message to TaskIns
|
|
246
|
+
taskins = message_to_taskins(msg)
|
|
247
|
+
# Add to list
|
|
248
|
+
task_ins_list.append(taskins)
|
|
249
|
+
# Call GrpcDriverHelper method
|
|
250
|
+
res = grpc_driver_helper.push_task_ins(
|
|
251
|
+
PushTaskInsRequest(task_ins_list=task_ins_list)
|
|
252
|
+
)
|
|
253
|
+
return list(res.task_ids)
|
|
254
|
+
|
|
255
|
+
def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
|
|
256
|
+
"""Pull messages based on message IDs.
|
|
257
|
+
|
|
258
|
+
This method is used to collect messages from the SuperLink that correspond to a
|
|
259
|
+
set of given message IDs.
|
|
260
|
+
"""
|
|
261
|
+
grpc_driver, _ = self._get_grpc_driver_helper_and_run_id()
|
|
262
|
+
# Pull TaskRes
|
|
263
|
+
res = grpc_driver.pull_task_res(
|
|
264
|
+
PullTaskResRequest(node=self.node, task_ids=message_ids)
|
|
265
|
+
)
|
|
266
|
+
# Convert TaskRes to Message
|
|
267
|
+
msgs = [message_from_taskres(taskres) for taskres in res.task_res_list]
|
|
268
|
+
return msgs
|
|
269
|
+
|
|
270
|
+
def send_and_receive(
|
|
271
|
+
self,
|
|
272
|
+
messages: Iterable[Message],
|
|
273
|
+
*,
|
|
274
|
+
timeout: Optional[float] = None,
|
|
275
|
+
) -> Iterable[Message]:
|
|
276
|
+
"""Push messages to specified node IDs and pull the reply messages.
|
|
277
|
+
|
|
278
|
+
This method sends a list of messages to their destination node IDs and then
|
|
279
|
+
waits for the replies. It continues to pull replies until either all replies are
|
|
280
|
+
received or the specified timeout duration is exceeded.
|
|
281
|
+
"""
|
|
282
|
+
# Push messages
|
|
283
|
+
msg_ids = set(self.push_messages(messages))
|
|
284
|
+
|
|
285
|
+
# Pull messages
|
|
286
|
+
end_time = time.time() + (timeout if timeout is not None else 0.0)
|
|
287
|
+
ret: List[Message] = []
|
|
288
|
+
while timeout is None or time.time() < end_time:
|
|
289
|
+
res_msgs = self.pull_messages(msg_ids)
|
|
290
|
+
ret.extend(res_msgs)
|
|
291
|
+
msg_ids.difference_update(
|
|
292
|
+
{msg.metadata.reply_to_message for msg in res_msgs}
|
|
293
|
+
)
|
|
294
|
+
if len(msg_ids) == 0:
|
|
295
|
+
break
|
|
296
|
+
# Sleep
|
|
297
|
+
time.sleep(3)
|
|
298
|
+
return ret
|
|
299
|
+
|
|
300
|
+
def close(self) -> None:
|
|
301
|
+
"""Disconnect from the SuperLink if connected."""
|
|
302
|
+
# Check if GrpcDriverHelper is initialized
|
|
303
|
+
if self.grpc_driver_helper is None:
|
|
304
|
+
return
|
|
305
|
+
# Disconnect
|
|
306
|
+
self.grpc_driver_helper.disconnect()
|
flwr/server/run_serverapp.py
CHANGED
|
@@ -25,7 +25,7 @@ from flwr.common import Context, EventType, RecordSet, event
|
|
|
25
25
|
from flwr.common.logger import log, update_console_handler
|
|
26
26
|
from flwr.common.object_ref import load_app
|
|
27
27
|
|
|
28
|
-
from .driver
|
|
28
|
+
from .driver import Driver, GrpcDriver
|
|
29
29
|
from .server_app import LoadServerAppError, ServerApp
|
|
30
30
|
|
|
31
31
|
|
|
@@ -128,13 +128,13 @@ def run_server_app() -> None:
|
|
|
128
128
|
server_app_dir = args.dir
|
|
129
129
|
server_app_attr = getattr(args, "server-app")
|
|
130
130
|
|
|
131
|
-
# Initialize
|
|
132
|
-
driver =
|
|
131
|
+
# Initialize GrpcDriver
|
|
132
|
+
driver = GrpcDriver(
|
|
133
133
|
driver_service_address=args.server,
|
|
134
134
|
root_certificates=root_certificates,
|
|
135
135
|
)
|
|
136
136
|
|
|
137
|
-
# Run the
|
|
137
|
+
# Run the ServerApp with the Driver
|
|
138
138
|
run(driver=driver, server_app_dir=server_app_dir, server_app_attr=server_app_attr)
|
|
139
139
|
|
|
140
140
|
# Clean up
|
|
@@ -29,7 +29,7 @@ import grpc
|
|
|
29
29
|
from flwr.client import ClientApp
|
|
30
30
|
from flwr.common import EventType, event, log
|
|
31
31
|
from flwr.common.typing import ConfigsRecordValues
|
|
32
|
-
from flwr.server.driver
|
|
32
|
+
from flwr.server.driver import Driver, GrpcDriver
|
|
33
33
|
from flwr.server.run_serverapp import run
|
|
34
34
|
from flwr.server.server_app import ServerApp
|
|
35
35
|
from flwr.server.superlink.driver.driver_grpc import run_driver_api_grpc
|
|
@@ -204,7 +204,7 @@ def _main_loop(
|
|
|
204
204
|
serverapp_th = None
|
|
205
205
|
try:
|
|
206
206
|
# Initialize Driver
|
|
207
|
-
driver =
|
|
207
|
+
driver = GrpcDriver(
|
|
208
208
|
driver_service_address=driver_api_address,
|
|
209
209
|
root_certificates=None,
|
|
210
210
|
)
|
{flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240423.dist-info}/RECORD
RENAMED
|
@@ -4,7 +4,7 @@ flwr/cli/app.py,sha256=38thPnMydBmNAxNE9mz4By-KdRUhJfoUgeDuAxMYF_U,1095
|
|
|
4
4
|
flwr/cli/config_utils.py,sha256=1wTPQqOU2fKeU4FP5KyG0xMa0F-qy8x1m2WvztPORb4,5597
|
|
5
5
|
flwr/cli/example.py,sha256=1bGDYll3BXQY2kRqSN-oICqS5n1b9m0g0RvXTopXHl4,2215
|
|
6
6
|
flwr/cli/new/__init__.py,sha256=cQzK1WH4JP2awef1t2UQ2xjl1agVEz9rwutV18SWV1k,789
|
|
7
|
-
flwr/cli/new/new.py,sha256=
|
|
7
|
+
flwr/cli/new/new.py,sha256=hqcHjun3keeREegDrdLJMPHKkVBYIN4HUUeCl3hzVgI,5404
|
|
8
8
|
flwr/cli/new/templates/__init__.py,sha256=4luU8RL-CK8JJCstQ_ON809W9bNTkY1l9zSaPKBkgwY,725
|
|
9
9
|
flwr/cli/new/templates/app/.gitignore.tpl,sha256=XixnHdyeMB2vwkGtGnwHqoWpH-9WChdyG0GXe57duhc,3078
|
|
10
10
|
flwr/cli/new/templates/app/README.md.tpl,sha256=_qGtgpKYKoCJVjQnvlBMKvFs_1gzTcL908I3KJg0oAM,668
|
|
@@ -13,13 +13,16 @@ flwr/cli/new/templates/app/code/__init__.py,sha256=EM6vfvgAILKPaPn7H1wMV1Wi01WyZ
|
|
|
13
13
|
flwr/cli/new/templates/app/code/__init__.py.tpl,sha256=olwrBeJemHNBWvjc6gJURloFRqW40dAy7FRQA5pDqHU,21
|
|
14
14
|
flwr/cli/new/templates/app/code/client.numpy.py.tpl,sha256=mTh7Y_jOJrPUvDYHVJy4wJCnjXZV_q-jlDkB07U5GSk,521
|
|
15
15
|
flwr/cli/new/templates/app/code/client.pytorch.py.tpl,sha256=671daPcdZaC4Z5k-dqmCovfb2_FShGmqfjwaR8y6EC8,1173
|
|
16
|
+
flwr/cli/new/templates/app/code/client.sklearn.py.tpl,sha256=S71SZiHaRXtKqUk3m5Elc_c6HhKAIKLalrKOQ3p20No,2801
|
|
16
17
|
flwr/cli/new/templates/app/code/client.tensorflow.py.tpl,sha256=N9SbnI65r2K9FHV_wn4JSpmVeyYpD0qEMehbHcGm4t0,1911
|
|
17
18
|
flwr/cli/new/templates/app/code/server.numpy.py.tpl,sha256=fRxrDXV7pB1aDhQUXMBmrCsC1zp0uKwsBxZBx1JzbHA,248
|
|
18
19
|
flwr/cli/new/templates/app/code/server.pytorch.py.tpl,sha256=xtKvUivNMzgOcLSOtnjWouJzIFbXdUQVYMm27uwyJpI,594
|
|
20
|
+
flwr/cli/new/templates/app/code/server.sklearn.py.tpl,sha256=cLzOpQzGIUzEazuFsjBpXAQUNPy6in6zR33SCqhix6o,341
|
|
19
21
|
flwr/cli/new/templates/app/code/server.tensorflow.py.tpl,sha256=GUGH8c_6cxgUB9obVJPaA4thxI7OVXsItyfQDsn9E5k,371
|
|
20
22
|
flwr/cli/new/templates/app/code/task.pytorch.py.tpl,sha256=NvajdZN-eTyfdqKK0v2MrvWITXw9BjJ3Ri5c1haPJDs,3684
|
|
21
23
|
flwr/cli/new/templates/app/pyproject.numpy.toml.tpl,sha256=0oTH0lY7q-PpRV4HA5woxJ1eWIgZRFcFsHa7-1lULIQ,489
|
|
22
24
|
flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl,sha256=GYbMAFD90JBRvy8fJbLU7nDITD3sxHv1TncQrg6mjEE,558
|
|
25
|
+
flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl,sha256=7p6s2jJpC8ZO-TfiJ0cE3fzkIhc4ndj9SY1hiYvSM5Q,538
|
|
23
26
|
flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl,sha256=7I8BYtE28cnc7ZiOlOp6_zeLsjLRlwa0Y4sjoP7r9VU,537
|
|
24
27
|
flwr/cli/run/__init__.py,sha256=oCd6HmQDx-sqver1gecgx-uMA38BLTSiiKpl7RGNceg,789
|
|
25
28
|
flwr/cli/run/run.py,sha256=qxXgShEXHONx-Gjpl515HF60QzRA-Ygpj2sbl0bZUAA,2331
|
|
@@ -51,7 +54,7 @@ flwr/client/numpy_client.py,sha256=u76GWAdHmJM88Agm2EgLQSvO8Jnk225mJTk-_TmPjFE,1
|
|
|
51
54
|
flwr/client/rest_client/__init__.py,sha256=ThwOnkMdzxo_UuyTI47Q7y9oSpuTgNT2OuFvJCfuDiw,735
|
|
52
55
|
flwr/client/rest_client/connection.py,sha256=ZxTFVDXlONqKTX6uYgxshoEWqzqVcQ8QQ2hKS93oLM8,11302
|
|
53
56
|
flwr/client/supernode/__init__.py,sha256=D5swXxemuRbA2rB_T9B8LwJW-_PucXwmlFQQerwIUv0,793
|
|
54
|
-
flwr/client/supernode/app.py,sha256=
|
|
57
|
+
flwr/client/supernode/app.py,sha256=gauvN8elkIy0vuT0GxT7MmkuBRY74ckZfpxejE7dduM,3861
|
|
55
58
|
flwr/client/typing.py,sha256=c9EvjlEjasxn1Wqx6bGl6Xg6vM1gMFfmXht-E2i5J-k,1006
|
|
56
59
|
flwr/common/__init__.py,sha256=dHOptgKxna78CEQLD5Yu0QIsoSgpIIw5AhIUZCHDWAU,3721
|
|
57
60
|
flwr/common/address.py,sha256=iTAN9jtmIGMrWFnx9XZQl45ZEtQJVZZLYPRBSNVARGI,1882
|
|
@@ -124,16 +127,16 @@ flwr/server/app.py,sha256=FriloRrkDHTlB5G7EBn6sH4v5GhiYFf_ZhbdROgjKbY,24199
|
|
|
124
127
|
flwr/server/client_manager.py,sha256=T8UDSRJBVD3fyIDI7NTAA-NA7GPrMNNgH2OAF54RRxE,6127
|
|
125
128
|
flwr/server/client_proxy.py,sha256=4G-oTwhb45sfWLx2uZdcXD98IZwdTS6F88xe3akCdUg,2399
|
|
126
129
|
flwr/server/compat/__init__.py,sha256=VxnJtJyOjNFQXMNi9hIuzNlZM5n0Hj1p3aq_Pm2udw4,892
|
|
127
|
-
flwr/server/compat/app.py,sha256=
|
|
130
|
+
flwr/server/compat/app.py,sha256=BhF3DySbvKkOIyNXnB1rwZhw8cC8yK_w91Fku8HmC_w,5287
|
|
128
131
|
flwr/server/compat/app_utils.py,sha256=S-M4sGIiZPXXgKFLjlbFP2yN7d-oIj6DaiJNPIZ2z3A,3503
|
|
129
|
-
flwr/server/compat/driver_client_proxy.py,sha256=
|
|
132
|
+
flwr/server/compat/driver_client_proxy.py,sha256=5XWroBrtA8MrQ5xQjgsju5RauMxNPshYLS_EtONEL1I,7370
|
|
130
133
|
flwr/server/compat/legacy_context.py,sha256=D2s7PvQoDnTexuRmf1uG9Von7GUj4Qqyr7qLklSlKAM,1766
|
|
131
134
|
flwr/server/criterion.py,sha256=ypbAexbztzGUxNen9RCHF91QeqiEQix4t4Ih3E-42MM,1061
|
|
132
|
-
flwr/server/driver/__init__.py,sha256=
|
|
133
|
-
flwr/server/driver/
|
|
134
|
-
flwr/server/driver/
|
|
135
|
+
flwr/server/driver/__init__.py,sha256=bbVL5pyA0Y2HcUK4s5U0B4epI-BuUFyEJbchew_8tJY,862
|
|
136
|
+
flwr/server/driver/driver.py,sha256=t9SSSDlo9wT_y2Nl7waGYMTm2VlkvK3_bOb7ggPPlho,5090
|
|
137
|
+
flwr/server/driver/grpc_driver.py,sha256=U5zfI3uYPUBaoOe4JI32t3dvCoSDacZ6EE0g9B8tKbU,11418
|
|
135
138
|
flwr/server/history.py,sha256=hDsoBaA4kUa6d1yvDVXuLluBqOBKSm0_fVDtUtYJkmg,5121
|
|
136
|
-
flwr/server/run_serverapp.py,sha256=
|
|
139
|
+
flwr/server/run_serverapp.py,sha256=3FqKVdFJ280dOVQQ63fu3kL7yNg_4ggtx2H7ljSBT1c,5604
|
|
137
140
|
flwr/server/server.py,sha256=UnBRlI6AGTj0nKeRtEQ3IalM3TJmggMKXhDyn8yKZNk,17664
|
|
138
141
|
flwr/server/server_app.py,sha256=KgAT_HqsfseTLNnfX2ph42PBbVqQ0lFzvYrT90V34y0,4402
|
|
139
142
|
flwr/server/server_config.py,sha256=CZaHVAsMvGLjpWVcLPkiYxgJN4xfIyAiUrCI3fETKY4,1349
|
|
@@ -204,9 +207,9 @@ flwr/simulation/ray_transport/__init__.py,sha256=FsaAnzC4cw4DqoouBCix6496k29jACk
|
|
|
204
207
|
flwr/simulation/ray_transport/ray_actor.py,sha256=_wv2eP7qxkCZ-6rMyYWnjLrGPBZRxjvTPjaVk8zIaQ4,19367
|
|
205
208
|
flwr/simulation/ray_transport/ray_client_proxy.py,sha256=oDu4sEPIOu39vrNi-fqDAe10xtNUXMO49bM2RWfRcyw,6738
|
|
206
209
|
flwr/simulation/ray_transport/utils.py,sha256=TYdtfg1P9VfTdLMOJlifInGpxWHYs9UfUqIv2wfkRLA,2392
|
|
207
|
-
flwr/simulation/run_simulation.py,sha256=
|
|
208
|
-
flwr_nightly-1.9.0.
|
|
209
|
-
flwr_nightly-1.9.0.
|
|
210
|
-
flwr_nightly-1.9.0.
|
|
211
|
-
flwr_nightly-1.9.0.
|
|
212
|
-
flwr_nightly-1.9.0.
|
|
210
|
+
flwr/simulation/run_simulation.py,sha256=nxXNv3r8ODImd5o6f0sa_w5L0I08LD2Udw2OTXStRnQ,15694
|
|
211
|
+
flwr_nightly-1.9.0.dev20240423.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
|
212
|
+
flwr_nightly-1.9.0.dev20240423.dist-info/METADATA,sha256=zdE6sLfyJNTW7D0GQYAswEN0TE1pJUSzVFZ_KgNmWYk,15260
|
|
213
|
+
flwr_nightly-1.9.0.dev20240423.dist-info/WHEEL,sha256=FMvqSimYX_P7y0a7UY-_Mc83r5zkBZsCYPm7Lr0Bsq4,88
|
|
214
|
+
flwr_nightly-1.9.0.dev20240423.dist-info/entry_points.txt,sha256=DBrrf685V2W9NbbchQwvuqBEpj5ik8tMZNoZg_W2bZY,363
|
|
215
|
+
flwr_nightly-1.9.0.dev20240423.dist-info/RECORD,,
|
flwr/server/driver/abc_driver.py
DELETED
|
@@ -1,140 +0,0 @@
|
|
|
1
|
-
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
15
|
-
"""Driver (abstract base class)."""
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
from abc import ABC, abstractmethod
|
|
19
|
-
from typing import Iterable, List, Optional
|
|
20
|
-
|
|
21
|
-
from flwr.common import Message, RecordSet
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
class Driver(ABC):
|
|
25
|
-
"""Abstract base Driver class for the Driver API."""
|
|
26
|
-
|
|
27
|
-
@abstractmethod
|
|
28
|
-
def create_message( # pylint: disable=too-many-arguments
|
|
29
|
-
self,
|
|
30
|
-
content: RecordSet,
|
|
31
|
-
message_type: str,
|
|
32
|
-
dst_node_id: int,
|
|
33
|
-
group_id: str,
|
|
34
|
-
ttl: Optional[float] = None,
|
|
35
|
-
) -> Message:
|
|
36
|
-
"""Create a new message with specified parameters.
|
|
37
|
-
|
|
38
|
-
This method constructs a new `Message` with given content and metadata.
|
|
39
|
-
The `run_id` and `src_node_id` will be set automatically.
|
|
40
|
-
|
|
41
|
-
Parameters
|
|
42
|
-
----------
|
|
43
|
-
content : RecordSet
|
|
44
|
-
The content for the new message. This holds records that are to be sent
|
|
45
|
-
to the destination node.
|
|
46
|
-
message_type : str
|
|
47
|
-
The type of the message, defining the action to be executed on
|
|
48
|
-
the receiving end.
|
|
49
|
-
dst_node_id : int
|
|
50
|
-
The ID of the destination node to which the message is being sent.
|
|
51
|
-
group_id : str
|
|
52
|
-
The ID of the group to which this message is associated. In some settings,
|
|
53
|
-
this is used as the FL round.
|
|
54
|
-
ttl : Optional[float] (default: None)
|
|
55
|
-
Time-to-live for the round trip of this message, i.e., the time from sending
|
|
56
|
-
this message to receiving a reply. It specifies in seconds the duration for
|
|
57
|
-
which the message and its potential reply are considered valid. If unset,
|
|
58
|
-
the default TTL (i.e., `common.DEFAULT_TTL`) will be used.
|
|
59
|
-
|
|
60
|
-
Returns
|
|
61
|
-
-------
|
|
62
|
-
message : Message
|
|
63
|
-
A new `Message` instance with the specified content and metadata.
|
|
64
|
-
"""
|
|
65
|
-
|
|
66
|
-
@abstractmethod
|
|
67
|
-
def get_node_ids(self) -> List[int]:
|
|
68
|
-
"""Get node IDs."""
|
|
69
|
-
|
|
70
|
-
@abstractmethod
|
|
71
|
-
def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
|
|
72
|
-
"""Push messages to specified node IDs.
|
|
73
|
-
|
|
74
|
-
This method takes an iterable of messages and sends each message
|
|
75
|
-
to the node specified in `dst_node_id`.
|
|
76
|
-
|
|
77
|
-
Parameters
|
|
78
|
-
----------
|
|
79
|
-
messages : Iterable[Message]
|
|
80
|
-
An iterable of messages to be sent.
|
|
81
|
-
|
|
82
|
-
Returns
|
|
83
|
-
-------
|
|
84
|
-
message_ids : Iterable[str]
|
|
85
|
-
An iterable of IDs for the messages that were sent, which can be used
|
|
86
|
-
to pull replies.
|
|
87
|
-
"""
|
|
88
|
-
|
|
89
|
-
@abstractmethod
|
|
90
|
-
def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
|
|
91
|
-
"""Pull messages based on message IDs.
|
|
92
|
-
|
|
93
|
-
This method is used to collect messages from the SuperLink
|
|
94
|
-
that correspond to a set of given message IDs.
|
|
95
|
-
|
|
96
|
-
Parameters
|
|
97
|
-
----------
|
|
98
|
-
message_ids : Iterable[str]
|
|
99
|
-
An iterable of message IDs for which reply messages are to be retrieved.
|
|
100
|
-
|
|
101
|
-
Returns
|
|
102
|
-
-------
|
|
103
|
-
messages : Iterable[Message]
|
|
104
|
-
An iterable of messages received.
|
|
105
|
-
"""
|
|
106
|
-
|
|
107
|
-
@abstractmethod
|
|
108
|
-
def send_and_receive(
|
|
109
|
-
self,
|
|
110
|
-
messages: Iterable[Message],
|
|
111
|
-
*,
|
|
112
|
-
timeout: Optional[float] = None,
|
|
113
|
-
) -> Iterable[Message]:
|
|
114
|
-
"""Push messages to specified node IDs and pull the reply messages.
|
|
115
|
-
|
|
116
|
-
This method sends a list of messages to their destination node IDs and then
|
|
117
|
-
waits for the replies. It continues to pull replies until either all
|
|
118
|
-
replies are received or the specified timeout duration is exceeded.
|
|
119
|
-
|
|
120
|
-
Parameters
|
|
121
|
-
----------
|
|
122
|
-
messages : Iterable[Message]
|
|
123
|
-
An iterable of messages to be sent.
|
|
124
|
-
timeout : Optional[float] (default: None)
|
|
125
|
-
The timeout duration in seconds. If specified, the method will wait for
|
|
126
|
-
replies for this duration. If `None`, there is no time limit and the method
|
|
127
|
-
will wait until replies for all messages are received.
|
|
128
|
-
|
|
129
|
-
Returns
|
|
130
|
-
-------
|
|
131
|
-
replies : Iterable[Message]
|
|
132
|
-
An iterable of reply messages received from the SuperLink.
|
|
133
|
-
|
|
134
|
-
Notes
|
|
135
|
-
-----
|
|
136
|
-
This method uses `push_messages` to send the messages and `pull_messages`
|
|
137
|
-
to collect the replies. If `timeout` is set, the method may not return
|
|
138
|
-
replies for all sent messages. A message remains valid until its TTL,
|
|
139
|
-
which is not affected by `timeout`.
|
|
140
|
-
"""
|
{flwr_nightly-1.9.0.dev20240420.dist-info → flwr_nightly-1.9.0.dev20240423.dist-info}/LICENSE
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|