flwr-nightly 1.13.0.dev20241021__py3-none-any.whl → 1.13.0.dev20241106__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/build.py +2 -2
- flwr/cli/config_utils.py +97 -0
- flwr/cli/log.py +63 -97
- flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -0
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/run.py +18 -83
- flwr/client/app.py +13 -14
- flwr/client/clientapp/app.py +1 -2
- flwr/client/{node_state.py → run_info_store.py} +4 -3
- flwr/client/supernode/app.py +6 -8
- flwr/common/constant.py +39 -4
- flwr/common/context.py +9 -4
- flwr/common/date.py +3 -3
- flwr/common/logger.py +103 -0
- flwr/common/serde.py +24 -0
- flwr/common/telemetry.py +0 -6
- flwr/common/typing.py +9 -0
- flwr/proto/exec_pb2.py +6 -6
- flwr/proto/exec_pb2.pyi +8 -2
- flwr/proto/log_pb2.py +29 -0
- flwr/proto/log_pb2.pyi +39 -0
- flwr/proto/log_pb2_grpc.py +4 -0
- flwr/proto/log_pb2_grpc.pyi +4 -0
- flwr/proto/message_pb2.py +8 -8
- flwr/proto/message_pb2.pyi +4 -1
- flwr/proto/serverappio_pb2.py +52 -0
- flwr/proto/{driver_pb2.pyi → serverappio_pb2.pyi} +54 -0
- flwr/proto/serverappio_pb2_grpc.py +376 -0
- flwr/proto/serverappio_pb2_grpc.pyi +147 -0
- flwr/proto/simulationio_pb2.py +38 -0
- flwr/proto/simulationio_pb2.pyi +65 -0
- flwr/proto/simulationio_pb2_grpc.py +171 -0
- flwr/proto/simulationio_pb2_grpc.pyi +68 -0
- flwr/server/app.py +247 -105
- flwr/server/driver/driver.py +15 -1
- flwr/server/driver/grpc_driver.py +26 -33
- flwr/server/driver/inmemory_driver.py +6 -14
- flwr/server/run_serverapp.py +29 -23
- flwr/server/{superlink/state → serverapp}/__init__.py +3 -9
- flwr/server/serverapp/app.py +270 -0
- flwr/server/strategy/fedadam.py +11 -1
- flwr/server/superlink/driver/__init__.py +1 -1
- flwr/server/superlink/driver/{driver_grpc.py → serverappio_grpc.py} +19 -16
- flwr/server/superlink/driver/{driver_servicer.py → serverappio_servicer.py} +125 -39
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +4 -2
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +2 -2
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +4 -2
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +7 -7
- flwr/server/superlink/fleet/rest_rere/rest_api.py +7 -7
- flwr/server/superlink/fleet/vce/vce_api.py +23 -23
- flwr/server/superlink/linkstate/__init__.py +28 -0
- flwr/server/superlink/{state/in_memory_state.py → linkstate/in_memory_linkstate.py} +180 -21
- flwr/server/superlink/{state/state.py → linkstate/linkstate.py} +144 -15
- flwr/server/superlink/{state/state_factory.py → linkstate/linkstate_factory.py} +9 -9
- flwr/server/superlink/{state/sqlite_state.py → linkstate/sqlite_linkstate.py} +300 -50
- flwr/server/superlink/{state → linkstate}/utils.py +84 -2
- flwr/server/superlink/simulation/__init__.py +15 -0
- flwr/server/superlink/simulation/simulationio_grpc.py +65 -0
- flwr/server/superlink/simulation/simulationio_servicer.py +132 -0
- flwr/simulation/__init__.py +2 -0
- flwr/simulation/app.py +1 -1
- flwr/simulation/ray_transport/ray_client_proxy.py +2 -2
- flwr/simulation/run_simulation.py +57 -131
- flwr/simulation/simulationio_connection.py +86 -0
- flwr/superexec/app.py +6 -134
- flwr/superexec/deployment.py +60 -65
- flwr/superexec/exec_grpc.py +15 -8
- flwr/superexec/exec_servicer.py +34 -63
- flwr/superexec/executor.py +22 -4
- flwr/superexec/simulation.py +13 -8
- {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241106.dist-info}/METADATA +1 -1
- {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241106.dist-info}/RECORD +77 -64
- {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241106.dist-info}/entry_points.txt +1 -0
- flwr/client/node_state_tests.py +0 -66
- flwr/proto/driver_pb2.py +0 -42
- flwr/proto/driver_pb2_grpc.py +0 -239
- flwr/proto/driver_pb2_grpc.pyi +0 -94
- {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241106.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.13.0.dev20241021.dist-info → flwr_nightly-1.13.0.dev20241106.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,270 @@
|
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Flower ServerApp process."""
|
|
16
|
+
|
|
17
|
+
import argparse
|
|
18
|
+
import sys
|
|
19
|
+
from logging import DEBUG, ERROR, INFO, WARN
|
|
20
|
+
from os.path import isfile
|
|
21
|
+
from pathlib import Path
|
|
22
|
+
from queue import Queue
|
|
23
|
+
from time import sleep
|
|
24
|
+
from typing import Optional
|
|
25
|
+
|
|
26
|
+
from flwr.cli.config_utils import get_fab_metadata
|
|
27
|
+
from flwr.cli.install import install_from_fab
|
|
28
|
+
from flwr.common.config import (
|
|
29
|
+
get_flwr_dir,
|
|
30
|
+
get_fused_config_from_dir,
|
|
31
|
+
get_project_config,
|
|
32
|
+
get_project_dir,
|
|
33
|
+
)
|
|
34
|
+
from flwr.common.constant import Status, SubStatus
|
|
35
|
+
from flwr.common.logger import (
|
|
36
|
+
log,
|
|
37
|
+
mirror_output_to_queue,
|
|
38
|
+
restore_output,
|
|
39
|
+
start_log_uploader,
|
|
40
|
+
stop_log_uploader,
|
|
41
|
+
)
|
|
42
|
+
from flwr.common.serde import (
|
|
43
|
+
context_from_proto,
|
|
44
|
+
context_to_proto,
|
|
45
|
+
fab_from_proto,
|
|
46
|
+
run_from_proto,
|
|
47
|
+
run_status_to_proto,
|
|
48
|
+
)
|
|
49
|
+
from flwr.common.typing import RunStatus
|
|
50
|
+
from flwr.proto.run_pb2 import UpdateRunStatusRequest # pylint: disable=E0611
|
|
51
|
+
from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
|
|
52
|
+
PullServerAppInputsRequest,
|
|
53
|
+
PullServerAppInputsResponse,
|
|
54
|
+
PushServerAppOutputsRequest,
|
|
55
|
+
)
|
|
56
|
+
from flwr.server.driver.grpc_driver import GrpcDriver
|
|
57
|
+
from flwr.server.run_serverapp import run as run_
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def flwr_serverapp() -> None:
    """Run process-isolated Flower ServerApp.

    Entry point of the ServerApp process: captures stdout/stderr into a
    queue, parses the command-line options, validates the TLS options and
    hands control to `run_serverapp`.

    The output capture is undone in a `finally` clause so that stdout/stderr
    are restored even when argument parsing exits (e.g. `--help`), the
    certificate options are rejected, or `run_serverapp` raises.
    """
    # Capture stdout/stderr
    log_queue: Queue[Optional[str]] = Queue()
    mirror_output_to_queue(log_queue)

    try:
        parser = argparse.ArgumentParser(
            description="Run a Flower ServerApp",
        )
        parser.add_argument(
            "--superlink",
            type=str,
            help="Address of SuperLink's DriverAPI",
        )
        parser.add_argument(
            "--run-once",
            action="store_true",
            help="When set, this process will start a single ServerApp "
            "for a pending Run. If no pending run the process will exit. ",
        )
        # NOTE(review): the whitespace inside this help text was not visible
        # in the reviewed source; argparse's default formatter re-wraps help
        # text, so the rendered output should not depend on it -- confirm.
        parser.add_argument(
            "--flwr-dir",
            default=None,
            help="""The path containing installed Flower Apps.
        By default, this value is equal to:

            - `$FLWR_HOME/` if `$FLWR_HOME` is defined
            - `$XDG_DATA_HOME/.flwr/` if `$XDG_DATA_HOME` is defined
            - `$HOME/.flwr/` in all other cases
        """,
        )
        parser.add_argument(
            "--insecure",
            action="store_true",
            help="Run the server without HTTPS, regardless of whether certificate "
            "paths are provided. By default, the server runs with HTTPS enabled. "
            "Use this flag only if you understand the risks.",
        )
        parser.add_argument(
            "--root-certificates",
            metavar="ROOT_CERT",
            type=str,
            help="Specifies the path to the PEM-encoded root certificate file for "
            "establishing secure HTTPS connections.",
        )
        args = parser.parse_args()

        log(INFO, "Starting Flower ServerApp")
        certificates = _try_obtain_certificates(args)

        log(
            DEBUG,
            "Starting isolated `ServerApp` connected to SuperLink DriverAPI at %s",
            args.superlink,
        )
        run_serverapp(
            superlink=args.superlink,
            log_queue=log_queue,
            run_once=args.run_once,
            flwr_dir_=args.flwr_dir,
            certificates=certificates,
        )
    finally:
        # Restore stdout/stderr, also on error or early exit
        restore_output()
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def _try_obtain_certificates(
    args: argparse.Namespace,
) -> Optional[bytes]:
    """Validate the TLS-related CLI options and load the root certificate.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed arguments; reads `insecure`, `root_certificates` and
        `superlink`.

    Returns
    -------
    Optional[bytes]
        `None` when `--insecure` is set, otherwise the raw bytes of the file
        given via `--root-certificates`.

    Raises
    ------
    SystemExit
        If `--insecure` and `--root-certificates` are combined, if secure mode
        is requested without `--root-certificates`, or if the given path does
        not point to a file.
    """
    cert_path = args.root_certificates

    if args.insecure:
        if cert_path is not None:
            sys.exit(
                "Conflicting options: The '--insecure' flag disables HTTPS, "
                "but '--root-certificates' was also specified. Please remove "
                "the '--root-certificates' option when running in insecure mode, "
                "or omit '--insecure' to use HTTPS."
            )
        log(
            WARN,
            "Option `--insecure` was set. Starting insecure HTTP channel to %s.",
            args.superlink,
        )
        return None

    # `isfile(None)` raises TypeError, so a missing path must be rejected
    # explicitly. Returning `None` here is not an option: the caller could not
    # tell it apart from the `--insecure` result above.
    if cert_path is None:
        sys.exit(
            "Secure mode requires '--root-certificates'. Provide the path to a "
            "PEM-encoded root certificate file, or pass '--insecure' to "
            "disable HTTPS."
        )
    if not isfile(cert_path):
        sys.exit("Path argument `--root-certificates` does not point to a file.")

    root_certificates = Path(cert_path).read_bytes()
    log(
        DEBUG,
        "Starting secure HTTPS channel to %s "
        "with the following certificates: %s.",
        args.superlink,
        cert_path,
    )
    return root_certificates
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def run_serverapp(  # pylint: disable=R0914, disable=W0212
    superlink: str,
    log_queue: Queue[Optional[str]],
    run_once: bool,
    flwr_dir_: Optional[str] = None,
    certificates: Optional[bytes] = None,
) -> None:
    """Run Flower ServerApp process.

    Repeatedly pulls ServerApp inputs (context, run, FAB) through the driver's
    stub, installs the FAB, runs the ServerApp, pushes the resulting context
    and reports the final run status.

    Parameters
    ----------
    superlink : str
        Address passed to `GrpcDriver` as `serverappio_service_address`.
    log_queue : Queue[Optional[str]]
        Queue of captured output lines handed to the log uploader.
    run_once : bool
        If `True`, leave the loop after the first iteration, whether or not a
        pending run was found.
    flwr_dir_ : Optional[str]
        Directory passed to `get_flwr_dir` to resolve where FABs are installed.
    certificates : Optional[bytes]
        Root certificates passed to `GrpcDriver`.
    """
    driver = GrpcDriver(
        serverappio_service_address=superlink,
        root_certificates=certificates,
    )

    # Resolve directory where FABs are installed
    flwr_dir = get_flwr_dir(flwr_dir_)
    log_uploader = None
    # Seconds to wait before polling again when there is nothing to run
    poll_interval = 3

    while True:
        # Reset per-iteration state so that `finally` never reports a status
        # for a run left over from an earlier iteration, and never touches an
        # unbound name when the failure happens before a run was pulled.
        run = None
        run_status = None

        try:
            # Pull ServerAppInputs from LinkState
            req = PullServerAppInputsRequest()
            res: PullServerAppInputsResponse = driver._stub.PullServerAppInputs(req)
            if not res.HasField("run"):
                sleep(poll_interval)
                continue

            context = context_from_proto(res.context)
            run = run_from_proto(res.run)
            fab = fab_from_proto(res.fab)

            driver.init_run(run.run_id)

            # Start log uploader for this run
            log_uploader = start_log_uploader(
                log_queue=log_queue,
                node_id=0,
                run_id=run.run_id,
                stub=driver._stub,
            )

            log(DEBUG, "ServerApp process starts FAB installation.")
            install_from_fab(fab.content, flwr_dir=flwr_dir, skip_prompt=True)

            fab_id, fab_version = get_fab_metadata(fab.content)

            app_path = str(get_project_dir(fab_id, fab_version, fab.hash_str, flwr_dir))
            config = get_project_config(app_path)

            # Obtain server app reference and the run config
            server_app_attr = config["tool"]["flwr"]["app"]["components"]["serverapp"]
            server_app_run_config = get_fused_config_from_dir(
                Path(app_path), run.override_config
            )

            # Update run_config in context
            context.run_config = server_app_run_config

            log(
                DEBUG,
                "Flower will load ServerApp `%s` in %s",
                server_app_attr,
                app_path,
            )

            # Change status to Running
            run_status_proto = run_status_to_proto(RunStatus(Status.RUNNING, "", ""))
            driver._stub.UpdateRunStatus(
                UpdateRunStatusRequest(run_id=run.run_id, run_status=run_status_proto)
            )

            # Load and run the ServerApp with the Driver
            updated_context = run_(
                driver=driver,
                server_app_dir=app_path,
                server_app_attr=server_app_attr,
                context=context,
            )

            # Send resulting context
            context_proto = context_to_proto(updated_context)
            out_req = PushServerAppOutputsRequest(
                run_id=run.run_id, context=context_proto
            )
            _ = driver._stub.PushServerAppOutputs(out_req)

            run_status = RunStatus(Status.FINISHED, SubStatus.COMPLETED, "")

        except Exception as ex:  # pylint: disable=broad-exception-caught
            exc_entity = "ServerApp"
            log(ERROR, "%s raised an exception", exc_entity, exc_info=ex)
            if run is None:
                # The failure happened before a run was obtained, so there is
                # no run to mark as failed. Back off before polling again.
                sleep(poll_interval)
            else:
                run_status = RunStatus(Status.FINISHED, SubStatus.FAILED, str(ex))

        finally:
            # Report the final status only for a run pulled in this iteration
            if run is not None and run_status is not None:
                run_status_proto = run_status_to_proto(run_status)
                driver._stub.UpdateRunStatus(
                    UpdateRunStatusRequest(
                        run_id=run.run_id, run_status=run_status_proto
                    )
                )

            # Stop log uploader for this run
            if log_uploader:
                stop_log_uploader(log_queue, log_uploader)
                log_uploader = None

            # Stop the loop if `flwr-serverapp` is expected to process a single run
            # NOTE: this check lives in `finally` so that the `continue` taken
            # when no run is pending also reaches it. A `break` inside
            # `finally` discards any exception still propagating.
            if run_once:
                break
|
flwr/server/strategy/fedadam.py
CHANGED
|
@@ -170,8 +170,18 @@ class FedAdam(FedOpt):
|
|
|
170
170
|
for x, y in zip(self.v_t, delta_t)
|
|
171
171
|
]
|
|
172
172
|
|
|
173
|
+
# Compute the bias-corrected learning rate, `eta_norm` for improving convergence
|
|
174
|
+
# in the early rounds of FL training. This `eta_norm` is `\alpha_t` in Kingma &
|
|
175
|
+
# Ba, 2014 (http://arxiv.org/abs/1412.6980) "Adam: A Method for Stochastic
|
|
176
|
+
# Optimization" in the formula line right before Section 2.1.
|
|
177
|
+
eta_norm = (
|
|
178
|
+
self.eta
|
|
179
|
+
* np.sqrt(1 - np.power(self.beta_2, server_round + 1.0))
|
|
180
|
+
/ (1 - np.power(self.beta_1, server_round + 1.0))
|
|
181
|
+
)
|
|
182
|
+
|
|
173
183
|
new_weights = [
|
|
174
|
-
x +
|
|
184
|
+
x + eta_norm * y / (np.sqrt(z) + self.tau)
|
|
175
185
|
for x, y, z in zip(self.current_weights, self.m_t, self.v_t)
|
|
176
186
|
]
|
|
177
187
|
|
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
"""
|
|
15
|
+
"""ServerAppIo gRPC API."""
|
|
16
16
|
|
|
17
17
|
from logging import INFO
|
|
18
18
|
from typing import Optional
|
|
@@ -21,37 +21,40 @@ import grpc
|
|
|
21
21
|
|
|
22
22
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
|
23
23
|
from flwr.common.logger import log
|
|
24
|
-
from flwr.proto.
|
|
25
|
-
|
|
24
|
+
from flwr.proto.serverappio_pb2_grpc import ( # pylint: disable=E0611
|
|
25
|
+
add_ServerAppIoServicer_to_server,
|
|
26
26
|
)
|
|
27
27
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
28
|
-
from flwr.server.superlink.
|
|
28
|
+
from flwr.server.superlink.linkstate import LinkStateFactory
|
|
29
29
|
|
|
30
30
|
from ..fleet.grpc_bidi.grpc_server import generic_create_grpc_server
|
|
31
|
-
from .
|
|
31
|
+
from .serverappio_servicer import ServerAppIoServicer
|
|
32
32
|
|
|
33
33
|
|
|
34
|
-
def
|
|
34
|
+
def run_serverappio_api_grpc(
    address: str,
    state_factory: LinkStateFactory,
    ffs_factory: FfsFactory,
    certificates: Optional[tuple[bytes, bytes, bytes]],
) -> grpc.Server:
    """Run ServerAppIo API (gRPC, request-response).

    Builds a `ServerAppIoServicer` from the given factories, registers it on a
    gRPC server created by `generic_create_grpc_server`, starts that server
    and returns it.

    Parameters
    ----------
    address : str
        Address the gRPC server is created for.
    state_factory : LinkStateFactory
        Factory handed to the servicer.
    ffs_factory : FfsFactory
        Factory handed to the servicer.
    certificates : Optional[tuple[bytes, bytes, bytes]]
        Certificates forwarded unchanged to `generic_create_grpc_server`.

    Returns
    -------
    grpc.Server
        The server, already started.
    """
    # Create ServerAppIo API gRPC server
    # The servicer is a `ServerAppIoServicer`, not a `grpc.Server`
    serverappio_servicer: ServerAppIoServicer = ServerAppIoServicer(
        state_factory=state_factory,
        ffs_factory=ffs_factory,
    )
    serverappio_add_servicer_to_server_fn = add_ServerAppIoServicer_to_server
    serverappio_grpc_server = generic_create_grpc_server(
        servicer_and_add_fn=(
            serverappio_servicer,
            serverappio_add_servicer_to_server_fn,
        ),
        server_address=address,
        max_message_length=GRPC_MAX_MESSAGE_LENGTH,
        certificates=certificates,
    )

    log(INFO, "Flower ECE: Starting ServerAppIo API (gRPC-rere) on %s", address)
    serverappio_grpc_server.start()

    return serverappio_grpc_server
|
|
@@ -12,62 +12,80 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
-
"""
|
|
15
|
+
"""ServerAppIo API servicer."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
import threading
|
|
18
19
|
import time
|
|
19
|
-
from logging import DEBUG
|
|
20
|
+
from logging import DEBUG, INFO
|
|
20
21
|
from typing import Optional
|
|
21
22
|
from uuid import UUID
|
|
22
23
|
|
|
23
24
|
import grpc
|
|
24
25
|
|
|
26
|
+
from flwr.common import ConfigsRecord
|
|
27
|
+
from flwr.common.constant import Status
|
|
25
28
|
from flwr.common.logger import log
|
|
26
29
|
from flwr.common.serde import (
|
|
30
|
+
context_from_proto,
|
|
31
|
+
context_to_proto,
|
|
27
32
|
fab_from_proto,
|
|
28
33
|
fab_to_proto,
|
|
34
|
+
run_status_from_proto,
|
|
35
|
+
run_to_proto,
|
|
29
36
|
user_config_from_proto,
|
|
30
|
-
user_config_to_proto,
|
|
31
|
-
)
|
|
32
|
-
from flwr.common.typing import Fab
|
|
33
|
-
from flwr.proto import driver_pb2_grpc # pylint: disable=E0611
|
|
34
|
-
from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
|
35
|
-
GetNodesRequest,
|
|
36
|
-
GetNodesResponse,
|
|
37
|
-
PullTaskResRequest,
|
|
38
|
-
PullTaskResResponse,
|
|
39
|
-
PushTaskInsRequest,
|
|
40
|
-
PushTaskInsResponse,
|
|
41
37
|
)
|
|
38
|
+
from flwr.common.typing import Fab, RunStatus
|
|
39
|
+
from flwr.proto import serverappio_pb2_grpc # pylint: disable=E0611
|
|
42
40
|
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
41
|
+
from flwr.proto.log_pb2 import ( # pylint: disable=E0611
|
|
42
|
+
PushLogsRequest,
|
|
43
|
+
PushLogsResponse,
|
|
44
|
+
)
|
|
43
45
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
44
46
|
from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
45
47
|
CreateRunRequest,
|
|
46
48
|
CreateRunResponse,
|
|
47
49
|
GetRunRequest,
|
|
48
50
|
GetRunResponse,
|
|
49
|
-
|
|
51
|
+
UpdateRunStatusRequest,
|
|
52
|
+
UpdateRunStatusResponse,
|
|
53
|
+
)
|
|
54
|
+
from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
|
|
55
|
+
GetNodesRequest,
|
|
56
|
+
GetNodesResponse,
|
|
57
|
+
PullServerAppInputsRequest,
|
|
58
|
+
PullServerAppInputsResponse,
|
|
59
|
+
PullTaskResRequest,
|
|
60
|
+
PullTaskResResponse,
|
|
61
|
+
PushServerAppOutputsRequest,
|
|
62
|
+
PushServerAppOutputsResponse,
|
|
63
|
+
PushTaskInsRequest,
|
|
64
|
+
PushTaskInsResponse,
|
|
50
65
|
)
|
|
51
66
|
from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611
|
|
52
67
|
from flwr.server.superlink.ffs.ffs import Ffs
|
|
53
68
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
54
|
-
from flwr.server.superlink.
|
|
69
|
+
from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
|
|
55
70
|
from flwr.server.utils.validator import validate_task_ins_or_res
|
|
56
71
|
|
|
57
72
|
|
|
58
|
-
class
|
|
59
|
-
"""
|
|
73
|
+
class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
|
|
74
|
+
"""ServerAppIo API servicer."""
|
|
60
75
|
|
|
61
|
-
def __init__(
|
|
76
|
+
def __init__(
    self, state_factory: LinkStateFactory, ffs_factory: FfsFactory
) -> None:
    """Store the LinkState/Ffs factories and create the servicer lock."""
    # Re-entrant lock; `PullServerAppInputs` holds it while selecting a
    # pending run.
    self.lock = threading.RLock()
    self.ffs_factory = ffs_factory
    self.state_factory = state_factory
|
|
64
82
|
|
|
65
83
|
def GetNodes(
|
|
66
84
|
self, request: GetNodesRequest, context: grpc.ServicerContext
|
|
67
85
|
) -> GetNodesResponse:
|
|
68
86
|
"""Get available nodes."""
|
|
69
|
-
log(DEBUG, "
|
|
70
|
-
state:
|
|
87
|
+
log(DEBUG, "ServerAppIoServicer.GetNodes")
|
|
88
|
+
state: LinkState = self.state_factory.state()
|
|
71
89
|
all_ids: set[int] = state.get_nodes(request.run_id)
|
|
72
90
|
nodes: list[Node] = [
|
|
73
91
|
Node(node_id=node_id, anonymous=False) for node_id in all_ids
|
|
@@ -78,8 +96,8 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
78
96
|
self, request: CreateRunRequest, context: grpc.ServicerContext
|
|
79
97
|
) -> CreateRunResponse:
|
|
80
98
|
"""Create run ID."""
|
|
81
|
-
log(DEBUG, "
|
|
82
|
-
state:
|
|
99
|
+
log(DEBUG, "ServerAppIoServicer.CreateRun")
|
|
100
|
+
state: LinkState = self.state_factory.state()
|
|
83
101
|
if request.HasField("fab"):
|
|
84
102
|
fab = fab_from_proto(request.fab)
|
|
85
103
|
ffs: Ffs = self.ffs_factory.ffs()
|
|
@@ -95,6 +113,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
95
113
|
request.fab_version,
|
|
96
114
|
fab_hash,
|
|
97
115
|
user_config_from_proto(request.override_config),
|
|
116
|
+
ConfigsRecord(),
|
|
98
117
|
)
|
|
99
118
|
return CreateRunResponse(run_id=run_id)
|
|
100
119
|
|
|
@@ -102,7 +121,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
102
121
|
self, request: PushTaskInsRequest, context: grpc.ServicerContext
|
|
103
122
|
) -> PushTaskInsResponse:
|
|
104
123
|
"""Push a set of TaskIns."""
|
|
105
|
-
log(DEBUG, "
|
|
124
|
+
log(DEBUG, "ServerAppIoServicer.PushTaskIns")
|
|
106
125
|
|
|
107
126
|
# Set pushed_at (timestamp in seconds)
|
|
108
127
|
pushed_at = time.time()
|
|
@@ -116,7 +135,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
116
135
|
_raise_if(bool(validation_errors), ", ".join(validation_errors))
|
|
117
136
|
|
|
118
137
|
# Init state
|
|
119
|
-
state:
|
|
138
|
+
state: LinkState = self.state_factory.state()
|
|
120
139
|
|
|
121
140
|
# Store each TaskIns
|
|
122
141
|
task_ids: list[Optional[UUID]] = []
|
|
@@ -132,17 +151,20 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
132
151
|
self, request: PullTaskResRequest, context: grpc.ServicerContext
|
|
133
152
|
) -> PullTaskResResponse:
|
|
134
153
|
"""Pull a set of TaskRes."""
|
|
135
|
-
log(DEBUG, "
|
|
154
|
+
log(DEBUG, "ServerAppIoServicer.PullTaskRes")
|
|
136
155
|
|
|
137
156
|
# Convert each task_id str to UUID
|
|
138
157
|
task_ids: set[UUID] = {UUID(task_id) for task_id in request.task_ids}
|
|
139
158
|
|
|
140
159
|
# Init state
|
|
141
|
-
state:
|
|
160
|
+
state: LinkState = self.state_factory.state()
|
|
142
161
|
|
|
143
162
|
# Register callback
|
|
144
163
|
def on_rpc_done() -> None:
|
|
145
|
-
log(
|
|
164
|
+
log(
|
|
165
|
+
DEBUG,
|
|
166
|
+
"ServerAppIoServicer.PullTaskRes callback: delete TaskIns/TaskRes",
|
|
167
|
+
)
|
|
146
168
|
|
|
147
169
|
if context.is_active():
|
|
148
170
|
return
|
|
@@ -164,10 +186,10 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
164
186
|
self, request: GetRunRequest, context: grpc.ServicerContext
|
|
165
187
|
) -> GetRunResponse:
|
|
166
188
|
"""Get run information."""
|
|
167
|
-
log(DEBUG, "
|
|
189
|
+
log(DEBUG, "ServerAppIoServicer.GetRun")
|
|
168
190
|
|
|
169
191
|
# Init state
|
|
170
|
-
state:
|
|
192
|
+
state: LinkState = self.state_factory.state()
|
|
171
193
|
|
|
172
194
|
# Retrieve run information
|
|
173
195
|
run = state.get_run(request.run_id)
|
|
@@ -175,21 +197,13 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
175
197
|
if run is None:
|
|
176
198
|
return GetRunResponse()
|
|
177
199
|
|
|
178
|
-
return GetRunResponse(
|
|
179
|
-
run=Run(
|
|
180
|
-
run_id=run.run_id,
|
|
181
|
-
fab_id=run.fab_id,
|
|
182
|
-
fab_version=run.fab_version,
|
|
183
|
-
override_config=user_config_to_proto(run.override_config),
|
|
184
|
-
fab_hash=run.fab_hash,
|
|
185
|
-
)
|
|
186
|
-
)
|
|
200
|
+
return GetRunResponse(run=run_to_proto(run))
|
|
187
201
|
|
|
188
202
|
def GetFab(
|
|
189
203
|
self, request: GetFabRequest, context: grpc.ServicerContext
|
|
190
204
|
) -> GetFabResponse:
|
|
191
205
|
"""Get FAB from Ffs."""
|
|
192
|
-
log(DEBUG, "
|
|
206
|
+
log(DEBUG, "ServerAppIoServicer.GetFab")
|
|
193
207
|
|
|
194
208
|
ffs: Ffs = self.ffs_factory.ffs()
|
|
195
209
|
if result := ffs.get(request.hash_str):
|
|
@@ -198,6 +212,78 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
198
212
|
|
|
199
213
|
raise ValueError(f"Found no FAB with hash: {request.hash_str}")
|
|
200
214
|
|
|
215
|
+
def PullServerAppInputs(
    self, request: PullServerAppInputsRequest, context: grpc.ServicerContext
) -> PullServerAppInputsResponse:
    """Pull ServerApp process inputs.

    Returns an empty response when no run is pending. Otherwise returns the
    context, run and FAB of the pending run after moving it to `STARTING`.

    Raises
    ------
    RuntimeError
        If the run, its FAB or its context is missing, or if the status
        update to `STARTING` is rejected.
    """
    log(DEBUG, "ServerAppIoServicer.PullServerAppInputs")
    # Init access to LinkState and Ffs
    state = self.state_factory.state()
    ffs = self.ffs_factory.ffs()

    # Hold the lock while picking a pending run, so that concurrent calls on
    # this servicer do not obtain the same run_id
    with self.lock:
        pending_id = state.get_pending_run_id()
        if pending_id is None:
            # Nothing to do: signal this with an empty response
            return PullServerAppInputsResponse()

        # Gather Context, Run and Fab belonging to the pending run
        ctx = state.get_serverapp_context(pending_id)
        run_record = state.get_run(pending_id)

        fab_obj = None
        if run_record and run_record.fab_hash:
            stored = ffs.get(run_record.fab_hash)
            if stored:
                fab_obj = Fab(run_record.fab_hash, stored[0])

        # The status update is only attempted when all inputs are present
        inputs_ready = bool(run_record and fab_obj and ctx)
        if not inputs_ready or not state.update_run_status(
            pending_id, RunStatus(Status.STARTING, "", "")
        ):
            raise RuntimeError(f"Failed to start run {pending_id}")

        log(INFO, "Starting run %d", pending_id)
        return PullServerAppInputsResponse(
            context=context_to_proto(ctx),
            run=run_to_proto(run_record),
            fab=fab_to_proto(fab_obj),
        )
|
|
252
|
+
|
|
253
|
+
def PushServerAppOutputs(
    self, request: PushServerAppOutputsRequest, context: grpc.ServicerContext
) -> PushServerAppOutputsResponse:
    """Push ServerApp process outputs.

    Stores the context carried by the request in LinkState under the
    request's run ID and returns an empty response.
    """
    log(DEBUG, "ServerAppIoServicer.PushServerAppOutputs")
    link_state = self.state_factory.state()
    # Deserialize first, then persist under the run ID
    received_context = context_from_proto(request.context)
    link_state.set_serverapp_context(request.run_id, received_context)
    return PushServerAppOutputsResponse()
|
|
261
|
+
|
|
262
|
+
def UpdateRunStatus(
    self, request: UpdateRunStatusRequest, context: grpc.ServicerContext
) -> UpdateRunStatusResponse:
    """Update the status of a run.

    Converts the status carried by the request and applies it in LinkState
    to the run identified by the request's run ID. The result of the update
    is not inspected; an empty response is always returned.
    """
    # Label matches the other methods of this servicer
    log(DEBUG, "ServerAppIoServicer.UpdateRunStatus")
    state = self.state_factory.state()

    # Update the run status
    state.update_run_status(
        run_id=request.run_id, new_status=run_status_from_proto(request.run_status)
    )
    return UpdateRunStatusResponse()
|
|
274
|
+
|
|
275
|
+
def PushLogs(
    self, request: PushLogsRequest, context: grpc.ServicerContext
) -> PushLogsResponse:
    """Push logs.

    Concatenates the log entries of the request (no separator added) and
    stores the result in LinkState as ServerApp log for the request's run ID.
    """
    log(DEBUG, "ServerAppIoServicer.PushLogs")
    link_state = self.state_factory.state()

    # Add logs to LinkState as a single string
    link_state.add_serverapp_log(request.run_id, "".join(request.logs))
    return PushLogsResponse()
|
|
286
|
+
|
|
201
287
|
|
|
202
288
|
def _raise_if(validation_error: bool, detail: str) -> None:
|
|
203
289
|
if validation_error:
|
|
@@ -48,7 +48,7 @@ from flwr.proto.grpcadapter_pb2 import MessageContainer # pylint: disable=E0611
|
|
|
48
48
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
49
49
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
50
50
|
from flwr.server.superlink.fleet.message_handler import message_handler
|
|
51
|
-
from flwr.server.superlink.
|
|
51
|
+
from flwr.server.superlink.linkstate import LinkStateFactory
|
|
52
52
|
|
|
53
53
|
T = TypeVar("T", bound=GrpcMessage)
|
|
54
54
|
|
|
@@ -77,7 +77,9 @@ def _handle(
|
|
|
77
77
|
class GrpcAdapterServicer(grpcadapter_pb2_grpc.GrpcAdapterServicer):
|
|
78
78
|
"""Fleet API via GrpcAdapter servicer."""
|
|
79
79
|
|
|
80
|
-
def __init__(
|
|
80
|
+
def __init__(
|
|
81
|
+
self, state_factory: LinkStateFactory, ffs_factory: FfsFactory
|
|
82
|
+
) -> None:
|
|
81
83
|
self.state_factory = state_factory
|
|
82
84
|
self.ffs_factory = ffs_factory
|
|
83
85
|
|
|
@@ -30,7 +30,7 @@ from flwr.proto.transport_pb2_grpc import ( # pylint: disable=E0611
|
|
|
30
30
|
add_FlowerServiceServicer_to_server,
|
|
31
31
|
)
|
|
32
32
|
from flwr.server.client_manager import ClientManager
|
|
33
|
-
from flwr.server.superlink.driver.
|
|
33
|
+
from flwr.server.superlink.driver.serverappio_servicer import ServerAppIoServicer
|
|
34
34
|
from flwr.server.superlink.fleet.grpc_adapter.grpc_adapter_servicer import (
|
|
35
35
|
GrpcAdapterServicer,
|
|
36
36
|
)
|
|
@@ -161,7 +161,7 @@ def generic_create_grpc_server( # pylint: disable=too-many-arguments,R0917
|
|
|
161
161
|
tuple[FleetServicer, AddServicerToServerFn],
|
|
162
162
|
tuple[GrpcAdapterServicer, AddServicerToServerFn],
|
|
163
163
|
tuple[FlowerServiceServicer, AddServicerToServerFn],
|
|
164
|
-
tuple[
|
|
164
|
+
tuple[ServerAppIoServicer, AddServicerToServerFn],
|
|
165
165
|
],
|
|
166
166
|
server_address: str,
|
|
167
167
|
max_concurrent_workers: int = 1000,
|