flwr-nightly 1.11.0.dev20240815__py3-none-any.whl → 1.11.0.dev20240816__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/config_utils.py +2 -2
- flwr/cli/run/run.py +5 -4
- flwr/client/app.py +96 -11
- flwr/client/grpc_rere_client/connection.py +8 -1
- flwr/client/process/clientappio_servicer.py +5 -6
- flwr/client/process/process.py +143 -0
- flwr/client/process/utils.py +108 -0
- flwr/client/rest_client/connection.py +14 -2
- flwr/client/supernode/app.py +25 -97
- flwr/server/app.py +3 -0
- flwr/server/run_serverapp.py +18 -2
- flwr/server/server.py +3 -1
- flwr/server/superlink/driver/driver_servicer.py +23 -8
- flwr/server/superlink/ffs/disk_ffs.py +6 -3
- flwr/server/superlink/ffs/ffs.py +3 -3
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +9 -3
- flwr/server/superlink/fleet/message_handler/message_handler.py +16 -1
- flwr/server/superlink/fleet/vce/vce_api.py +2 -2
- flwr/server/workflow/default_workflows.py +3 -1
- flwr/superexec/deployment.py +8 -9
- {flwr_nightly-1.11.0.dev20240815.dist-info → flwr_nightly-1.11.0.dev20240816.dist-info}/METADATA +1 -1
- {flwr_nightly-1.11.0.dev20240815.dist-info → flwr_nightly-1.11.0.dev20240816.dist-info}/RECORD +25 -23
- {flwr_nightly-1.11.0.dev20240815.dist-info → flwr_nightly-1.11.0.dev20240816.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.11.0.dev20240815.dist-info → flwr_nightly-1.11.0.dev20240816.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.11.0.dev20240815.dist-info → flwr_nightly-1.11.0.dev20240816.dist-info}/entry_points.txt +0 -0
flwr/client/supernode/app.py
CHANGED
|
@@ -18,7 +18,7 @@ import argparse
|
|
|
18
18
|
import sys
|
|
19
19
|
from logging import DEBUG, INFO, WARN
|
|
20
20
|
from pathlib import Path
|
|
21
|
-
from typing import
|
|
21
|
+
from typing import Optional, Tuple
|
|
22
22
|
|
|
23
23
|
from cryptography.exceptions import UnsupportedAlgorithm
|
|
24
24
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
@@ -27,15 +27,8 @@ from cryptography.hazmat.primitives.serialization import (
|
|
|
27
27
|
load_ssh_public_key,
|
|
28
28
|
)
|
|
29
29
|
|
|
30
|
-
from flwr.client.client_app import ClientApp, LoadClientAppError
|
|
31
30
|
from flwr.common import EventType, event
|
|
32
|
-
from flwr.common.config import
|
|
33
|
-
get_flwr_dir,
|
|
34
|
-
get_metadata_from_config,
|
|
35
|
-
get_project_config,
|
|
36
|
-
get_project_dir,
|
|
37
|
-
parse_config_args,
|
|
38
|
-
)
|
|
31
|
+
from flwr.common.config import parse_config_args
|
|
39
32
|
from flwr.common.constant import (
|
|
40
33
|
TRANSPORT_TYPE_GRPC_ADAPTER,
|
|
41
34
|
TRANSPORT_TYPE_GRPC_RERE,
|
|
@@ -43,9 +36,10 @@ from flwr.common.constant import (
|
|
|
43
36
|
)
|
|
44
37
|
from flwr.common.exit_handlers import register_exit_handlers
|
|
45
38
|
from flwr.common.logger import log, warn_deprecated_feature
|
|
46
|
-
from flwr.common.object_ref import load_app, validate
|
|
47
39
|
|
|
48
|
-
from ..app import
|
|
40
|
+
from ..app import start_client_internal
|
|
41
|
+
from ..process.process import run_clientapp
|
|
42
|
+
from ..process.utils import get_load_client_app_fn
|
|
49
43
|
|
|
50
44
|
ADDRESS_FLEET_API_GRPC_RERE = "0.0.0.0:9092"
|
|
51
45
|
|
|
@@ -61,7 +55,7 @@ def run_supernode() -> None:
|
|
|
61
55
|
_warn_deprecated_server_arg(args)
|
|
62
56
|
|
|
63
57
|
root_certificates = _get_certificates(args)
|
|
64
|
-
load_fn =
|
|
58
|
+
load_fn = get_load_client_app_fn(
|
|
65
59
|
default_app_ref="",
|
|
66
60
|
app_path=args.app,
|
|
67
61
|
flwr_dir=args.flwr_dir,
|
|
@@ -69,7 +63,7 @@ def run_supernode() -> None:
|
|
|
69
63
|
)
|
|
70
64
|
authentication_keys = _try_setup_client_authentication(args)
|
|
71
65
|
|
|
72
|
-
|
|
66
|
+
start_client_internal(
|
|
73
67
|
server_address=args.superlink,
|
|
74
68
|
load_client_app_fn=load_fn,
|
|
75
69
|
transport=args.transport,
|
|
@@ -79,7 +73,8 @@ def run_supernode() -> None:
|
|
|
79
73
|
max_retries=args.max_retries,
|
|
80
74
|
max_wait_time=args.max_wait_time,
|
|
81
75
|
node_config=parse_config_args([args.node_config]),
|
|
82
|
-
|
|
76
|
+
isolate=args.isolate,
|
|
77
|
+
supernode_address=args.supernode_address,
|
|
83
78
|
)
|
|
84
79
|
|
|
85
80
|
# Graceful shutdown
|
|
@@ -99,14 +94,14 @@ def run_client_app() -> None:
|
|
|
99
94
|
_warn_deprecated_server_arg(args)
|
|
100
95
|
|
|
101
96
|
root_certificates = _get_certificates(args)
|
|
102
|
-
load_fn =
|
|
97
|
+
load_fn = get_load_client_app_fn(
|
|
103
98
|
default_app_ref=getattr(args, "client-app"),
|
|
104
99
|
app_path=args.dir,
|
|
105
100
|
multi_app=False,
|
|
106
101
|
)
|
|
107
102
|
authentication_keys = _try_setup_client_authentication(args)
|
|
108
103
|
|
|
109
|
-
|
|
104
|
+
start_client_internal(
|
|
110
105
|
server_address=args.superlink,
|
|
111
106
|
node_config=parse_config_args([args.node_config]),
|
|
112
107
|
load_client_app_fn=load_fn,
|
|
@@ -128,7 +123,7 @@ def flwr_clientapp() -> None:
|
|
|
128
123
|
description="Run a Flower ClientApp",
|
|
129
124
|
)
|
|
130
125
|
parser.add_argument(
|
|
131
|
-
"--
|
|
126
|
+
"--supernode",
|
|
132
127
|
help="Address of SuperNode ClientAppIo gRPC servicer",
|
|
133
128
|
)
|
|
134
129
|
parser.add_argument(
|
|
@@ -140,9 +135,10 @@ def flwr_clientapp() -> None:
|
|
|
140
135
|
DEBUG,
|
|
141
136
|
"Staring isolated `ClientApp` connected to SuperNode ClientAppIo at %s "
|
|
142
137
|
"with the token %s",
|
|
143
|
-
args.
|
|
138
|
+
args.supernode,
|
|
144
139
|
args.token,
|
|
145
140
|
)
|
|
141
|
+
run_clientapp(supernode=args.supernode, token=int(args.token))
|
|
146
142
|
|
|
147
143
|
|
|
148
144
|
def _warn_deprecated_server_arg(args: argparse.Namespace) -> None:
|
|
@@ -200,85 +196,6 @@ def _get_certificates(args: argparse.Namespace) -> Optional[bytes]:
|
|
|
200
196
|
return root_certificates
|
|
201
197
|
|
|
202
198
|
|
|
203
|
-
def _get_load_client_app_fn(
|
|
204
|
-
default_app_ref: str,
|
|
205
|
-
app_path: Optional[str],
|
|
206
|
-
multi_app: bool,
|
|
207
|
-
flwr_dir: Optional[str] = None,
|
|
208
|
-
) -> Callable[[str, str], ClientApp]:
|
|
209
|
-
"""Get the load_client_app_fn function.
|
|
210
|
-
|
|
211
|
-
If `multi_app` is True, this function loads the specified ClientApp
|
|
212
|
-
based on `fab_id` and `fab_version`. If `fab_id` is empty, a default
|
|
213
|
-
ClientApp will be loaded.
|
|
214
|
-
|
|
215
|
-
If `multi_app` is False, it ignores `fab_id` and `fab_version` and
|
|
216
|
-
loads a default ClientApp.
|
|
217
|
-
"""
|
|
218
|
-
if not multi_app:
|
|
219
|
-
log(
|
|
220
|
-
DEBUG,
|
|
221
|
-
"Flower SuperNode will load and validate ClientApp `%s`",
|
|
222
|
-
default_app_ref,
|
|
223
|
-
)
|
|
224
|
-
|
|
225
|
-
valid, error_msg = validate(default_app_ref, project_dir=app_path)
|
|
226
|
-
if not valid and error_msg:
|
|
227
|
-
raise LoadClientAppError(error_msg) from None
|
|
228
|
-
|
|
229
|
-
def _load(fab_id: str, fab_version: str) -> ClientApp:
|
|
230
|
-
runtime_app_dir = Path(app_path if app_path else "").absolute()
|
|
231
|
-
# If multi-app feature is disabled
|
|
232
|
-
if not multi_app:
|
|
233
|
-
# Set app reference
|
|
234
|
-
client_app_ref = default_app_ref
|
|
235
|
-
# If multi-app feature is enabled but app directory is provided
|
|
236
|
-
elif app_path is not None:
|
|
237
|
-
config = get_project_config(runtime_app_dir)
|
|
238
|
-
this_fab_version, this_fab_id = get_metadata_from_config(config)
|
|
239
|
-
|
|
240
|
-
if this_fab_version != fab_version or this_fab_id != fab_id:
|
|
241
|
-
raise LoadClientAppError(
|
|
242
|
-
f"FAB ID or version mismatch: Expected FAB ID '{this_fab_id}' and "
|
|
243
|
-
f"FAB version '{this_fab_version}', but received FAB ID '{fab_id}' "
|
|
244
|
-
f"and FAB version '{fab_version}'.",
|
|
245
|
-
) from None
|
|
246
|
-
|
|
247
|
-
# log(WARN, "FAB ID is not provided; the default ClientApp will be loaded.")
|
|
248
|
-
|
|
249
|
-
# Set app reference
|
|
250
|
-
client_app_ref = config["tool"]["flwr"]["app"]["components"]["clientapp"]
|
|
251
|
-
# If multi-app feature is enabled
|
|
252
|
-
else:
|
|
253
|
-
try:
|
|
254
|
-
runtime_app_dir = get_project_dir(
|
|
255
|
-
fab_id, fab_version, get_flwr_dir(flwr_dir)
|
|
256
|
-
)
|
|
257
|
-
config = get_project_config(runtime_app_dir)
|
|
258
|
-
except Exception as e:
|
|
259
|
-
raise LoadClientAppError("Failed to load ClientApp") from e
|
|
260
|
-
|
|
261
|
-
# Set app reference
|
|
262
|
-
client_app_ref = config["tool"]["flwr"]["app"]["components"]["clientapp"]
|
|
263
|
-
|
|
264
|
-
# Load ClientApp
|
|
265
|
-
log(
|
|
266
|
-
DEBUG,
|
|
267
|
-
"Loading ClientApp `%s`",
|
|
268
|
-
client_app_ref,
|
|
269
|
-
)
|
|
270
|
-
client_app = load_app(client_app_ref, LoadClientAppError, runtime_app_dir)
|
|
271
|
-
|
|
272
|
-
if not isinstance(client_app, ClientApp):
|
|
273
|
-
raise LoadClientAppError(
|
|
274
|
-
f"Attribute {client_app_ref} is not of type {ClientApp}",
|
|
275
|
-
) from None
|
|
276
|
-
|
|
277
|
-
return client_app
|
|
278
|
-
|
|
279
|
-
return _load
|
|
280
|
-
|
|
281
|
-
|
|
282
199
|
def _parse_args_run_supernode() -> argparse.ArgumentParser:
|
|
283
200
|
"""Parse flower-supernode command line arguments."""
|
|
284
201
|
parser = argparse.ArgumentParser(
|
|
@@ -308,6 +225,17 @@ def _parse_args_run_supernode() -> argparse.ArgumentParser:
|
|
|
308
225
|
- `$HOME/.flwr/` in all other cases
|
|
309
226
|
""",
|
|
310
227
|
)
|
|
228
|
+
parser.add_argument(
|
|
229
|
+
"--isolate",
|
|
230
|
+
action="store_true",
|
|
231
|
+
help="Run `ClientApp` in an isolated subprocess. By default, `ClientApp` "
|
|
232
|
+
"runs in the same process that executes the SuperNode.",
|
|
233
|
+
)
|
|
234
|
+
parser.add_argument(
|
|
235
|
+
"--supernode-address",
|
|
236
|
+
default="0.0.0.0:9094",
|
|
237
|
+
help="Set the SuperNode gRPC server address. Defaults to `0.0.0.0:9094`.",
|
|
238
|
+
)
|
|
311
239
|
|
|
312
240
|
return parser
|
|
313
241
|
|
flwr/server/app.py
CHANGED
|
@@ -301,6 +301,7 @@ def run_superlink() -> None:
|
|
|
301
301
|
fleet_server = _run_fleet_api_grpc_rere(
|
|
302
302
|
address=fleet_address,
|
|
303
303
|
state_factory=state_factory,
|
|
304
|
+
ffs_factory=ffs_factory,
|
|
304
305
|
certificates=certificates,
|
|
305
306
|
interceptors=interceptors,
|
|
306
307
|
)
|
|
@@ -487,6 +488,7 @@ def _try_obtain_certificates(
|
|
|
487
488
|
def _run_fleet_api_grpc_rere(
|
|
488
489
|
address: str,
|
|
489
490
|
state_factory: StateFactory,
|
|
491
|
+
ffs_factory: FfsFactory,
|
|
490
492
|
certificates: Optional[Tuple[bytes, bytes, bytes]],
|
|
491
493
|
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None,
|
|
492
494
|
) -> grpc.Server:
|
|
@@ -494,6 +496,7 @@ def _run_fleet_api_grpc_rere(
|
|
|
494
496
|
# Create Fleet API gRPC server
|
|
495
497
|
fleet_servicer = FleetServicer(
|
|
496
498
|
state_factory=state_factory,
|
|
499
|
+
ffs_factory=ffs_factory,
|
|
497
500
|
)
|
|
498
501
|
fleet_add_servicer_to_server_fn = add_FleetServicer_to_server
|
|
499
502
|
fleet_grpc_server = generic_create_grpc_server(
|
flwr/server/run_serverapp.py
CHANGED
|
@@ -21,6 +21,8 @@ from logging import DEBUG, INFO, WARN
|
|
|
21
21
|
from pathlib import Path
|
|
22
22
|
from typing import Optional
|
|
23
23
|
|
|
24
|
+
from flwr.cli.config_utils import get_fab_metadata
|
|
25
|
+
from flwr.cli.install import install_from_fab
|
|
24
26
|
from flwr.common import Context, EventType, RecordSet, event
|
|
25
27
|
from flwr.common.config import (
|
|
26
28
|
get_flwr_dir,
|
|
@@ -36,6 +38,7 @@ from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
|
|
36
38
|
CreateRunRequest,
|
|
37
39
|
CreateRunResponse,
|
|
38
40
|
)
|
|
41
|
+
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
39
42
|
|
|
40
43
|
from .driver import Driver
|
|
41
44
|
from .driver.grpc_driver import GrpcDriver
|
|
@@ -87,7 +90,8 @@ def run(
|
|
|
87
90
|
log(DEBUG, "ServerApp finished running.")
|
|
88
91
|
|
|
89
92
|
|
|
90
|
-
|
|
93
|
+
# pylint: disable-next=too-many-branches,too-many-statements,too-many-locals
|
|
94
|
+
def run_server_app() -> None:
|
|
91
95
|
"""Run Flower server app."""
|
|
92
96
|
event(EventType.RUN_SERVER_APP_ENTER)
|
|
93
97
|
|
|
@@ -164,7 +168,19 @@ def run_server_app() -> None: # pylint: disable=too-many-branches
|
|
|
164
168
|
)
|
|
165
169
|
flwr_dir = get_flwr_dir(args.flwr_dir)
|
|
166
170
|
run_ = driver.run
|
|
167
|
-
|
|
171
|
+
if run_.fab_hash:
|
|
172
|
+
fab_req = GetFabRequest(hash_str=run_.fab_hash)
|
|
173
|
+
# pylint: disable-next=W0212
|
|
174
|
+
fab_res: GetFabResponse = driver._stub.GetFab(fab_req)
|
|
175
|
+
if fab_res.fab.hash_str != run_.fab_hash:
|
|
176
|
+
raise ValueError("FAB hashes don't match.")
|
|
177
|
+
|
|
178
|
+
install_from_fab(fab_res.fab.content, flwr_dir, True)
|
|
179
|
+
fab_id, fab_version = get_fab_metadata(fab_res.fab.content)
|
|
180
|
+
else:
|
|
181
|
+
fab_id, fab_version = run_.fab_id, run_.fab_version
|
|
182
|
+
|
|
183
|
+
app_path = str(get_project_dir(fab_id, fab_version, flwr_dir))
|
|
168
184
|
config = get_project_config(app_path)
|
|
169
185
|
else:
|
|
170
186
|
# User provided `app_dir`, but not `--run-id`
|
flwr/server/server.py
CHANGED
|
@@ -91,7 +91,7 @@ class Server:
|
|
|
91
91
|
# Initialize parameters
|
|
92
92
|
log(INFO, "[INIT]")
|
|
93
93
|
self.parameters = self._get_initial_parameters(server_round=0, timeout=timeout)
|
|
94
|
-
log(INFO, "
|
|
94
|
+
log(INFO, "Starting evaluation of initial global parameters")
|
|
95
95
|
res = self.strategy.evaluate(0, parameters=self.parameters)
|
|
96
96
|
if res is not None:
|
|
97
97
|
log(
|
|
@@ -102,6 +102,8 @@ class Server:
|
|
|
102
102
|
)
|
|
103
103
|
history.add_loss_centralized(server_round=0, loss=res[0])
|
|
104
104
|
history.add_metrics_centralized(server_round=0, metrics=res[1])
|
|
105
|
+
else:
|
|
106
|
+
log(INFO, "Evaluation returned no results (`None`)")
|
|
105
107
|
|
|
106
108
|
# Run federated learning for num_rounds
|
|
107
109
|
start_time = timeit.default_timer()
|
|
@@ -23,7 +23,13 @@ from uuid import UUID
|
|
|
23
23
|
import grpc
|
|
24
24
|
|
|
25
25
|
from flwr.common.logger import log
|
|
26
|
-
from flwr.common.serde import
|
|
26
|
+
from flwr.common.serde import (
|
|
27
|
+
fab_from_proto,
|
|
28
|
+
fab_to_proto,
|
|
29
|
+
user_config_from_proto,
|
|
30
|
+
user_config_to_proto,
|
|
31
|
+
)
|
|
32
|
+
from flwr.common.typing import Fab
|
|
27
33
|
from flwr.proto import driver_pb2_grpc # pylint: disable=E0611
|
|
28
34
|
from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
|
29
35
|
CreateRunRequest,
|
|
@@ -43,7 +49,7 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
|
43
49
|
Run,
|
|
44
50
|
)
|
|
45
51
|
from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611
|
|
46
|
-
from flwr.server.superlink.ffs import Ffs
|
|
52
|
+
from flwr.server.superlink.ffs.ffs import Ffs
|
|
47
53
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
48
54
|
from flwr.server.superlink.state import State, StateFactory
|
|
49
55
|
from flwr.server.utils.validator import validate_task_ins_or_res
|
|
@@ -74,12 +80,13 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
74
80
|
"""Create run ID."""
|
|
75
81
|
log(DEBUG, "DriverServicer.CreateRun")
|
|
76
82
|
state: State = self.state_factory.state()
|
|
77
|
-
if request.HasField("fab")
|
|
83
|
+
if request.HasField("fab"):
|
|
84
|
+
fab = fab_from_proto(request.fab)
|
|
78
85
|
ffs: Ffs = self.ffs_factory.ffs()
|
|
79
|
-
fab_hash = ffs.put(
|
|
86
|
+
fab_hash = ffs.put(fab.content, {})
|
|
80
87
|
_raise_if(
|
|
81
|
-
fab_hash !=
|
|
82
|
-
f"FAB ({
|
|
88
|
+
fab_hash != fab.hash_str,
|
|
89
|
+
f"FAB ({fab.hash_str}) hash from request doesn't match contents",
|
|
83
90
|
)
|
|
84
91
|
else:
|
|
85
92
|
fab_hash = ""
|
|
@@ -174,14 +181,22 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
174
181
|
fab_id=run.fab_id,
|
|
175
182
|
fab_version=run.fab_version,
|
|
176
183
|
override_config=user_config_to_proto(run.override_config),
|
|
184
|
+
fab_hash=run.fab_hash,
|
|
177
185
|
)
|
|
178
186
|
)
|
|
179
187
|
|
|
180
188
|
def GetFab(
|
|
181
189
|
self, request: GetFabRequest, context: grpc.ServicerContext
|
|
182
190
|
) -> GetFabResponse:
|
|
183
|
-
"""
|
|
184
|
-
|
|
191
|
+
"""Get FAB from Ffs."""
|
|
192
|
+
log(DEBUG, "DriverServicer.GetFab")
|
|
193
|
+
|
|
194
|
+
ffs: Ffs = self.ffs_factory.ffs()
|
|
195
|
+
if result := ffs.get(request.hash_str):
|
|
196
|
+
fab = Fab(request.hash_str, result[0])
|
|
197
|
+
return GetFabResponse(fab=fab_to_proto(fab))
|
|
198
|
+
|
|
199
|
+
raise ValueError(f"Found no FAB with hash: {request.hash_str}")
|
|
185
200
|
|
|
186
201
|
|
|
187
202
|
def _raise_if(validation_error: bool, detail: str) -> None:
|
|
@@ -17,7 +17,7 @@
|
|
|
17
17
|
import hashlib
|
|
18
18
|
import json
|
|
19
19
|
from pathlib import Path
|
|
20
|
-
from typing import Dict, List, Tuple
|
|
20
|
+
from typing import Dict, List, Optional, Tuple
|
|
21
21
|
|
|
22
22
|
from flwr.server.superlink.ffs.ffs import Ffs
|
|
23
23
|
|
|
@@ -58,7 +58,7 @@ class DiskFfs(Ffs): # pylint: disable=R0904
|
|
|
58
58
|
|
|
59
59
|
return content_hash
|
|
60
60
|
|
|
61
|
-
def get(self, key: str) -> Tuple[bytes, Dict[str, str]]:
|
|
61
|
+
def get(self, key: str) -> Optional[Tuple[bytes, Dict[str, str]]]:
|
|
62
62
|
"""Return tuple containing the object content and metadata.
|
|
63
63
|
|
|
64
64
|
Parameters
|
|
@@ -68,9 +68,12 @@ class DiskFfs(Ffs): # pylint: disable=R0904
|
|
|
68
68
|
|
|
69
69
|
Returns
|
|
70
70
|
-------
|
|
71
|
-
Tuple[bytes, Dict[str, str]]
|
|
71
|
+
Optional[Tuple[bytes, Dict[str, str]]]
|
|
72
72
|
A tuple containing the object content and metadata.
|
|
73
73
|
"""
|
|
74
|
+
if not (self.base_dir / key).exists():
|
|
75
|
+
return None
|
|
76
|
+
|
|
74
77
|
content = (self.base_dir / key).read_bytes()
|
|
75
78
|
meta = json.loads((self.base_dir / f"{key}.META").read_text())
|
|
76
79
|
|
flwr/server/superlink/ffs/ffs.py
CHANGED
|
@@ -16,7 +16,7 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import abc
|
|
19
|
-
from typing import Dict, List, Tuple
|
|
19
|
+
from typing import Dict, List, Optional, Tuple
|
|
20
20
|
|
|
21
21
|
|
|
22
22
|
class Ffs(abc.ABC): # pylint: disable=R0904
|
|
@@ -40,7 +40,7 @@ class Ffs(abc.ABC): # pylint: disable=R0904
|
|
|
40
40
|
"""
|
|
41
41
|
|
|
42
42
|
@abc.abstractmethod
|
|
43
|
-
def get(self, key: str) -> Tuple[bytes, Dict[str, str]]:
|
|
43
|
+
def get(self, key: str) -> Optional[Tuple[bytes, Dict[str, str]]]:
|
|
44
44
|
"""Return tuple containing the object content and metadata.
|
|
45
45
|
|
|
46
46
|
Parameters
|
|
@@ -50,7 +50,7 @@ class Ffs(abc.ABC): # pylint: disable=R0904
|
|
|
50
50
|
|
|
51
51
|
Returns
|
|
52
52
|
-------
|
|
53
|
-
Tuple[bytes, Dict[str, str]]
|
|
53
|
+
Optional[Tuple[bytes, Dict[str, str]]]
|
|
54
54
|
A tuple containing the object content and metadata.
|
|
55
55
|
"""
|
|
56
56
|
|
|
@@ -35,6 +35,7 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
35
35
|
PushTaskResResponse,
|
|
36
36
|
)
|
|
37
37
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
38
|
+
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
38
39
|
from flwr.server.superlink.fleet.message_handler import message_handler
|
|
39
40
|
from flwr.server.superlink.state import StateFactory
|
|
40
41
|
|
|
@@ -42,8 +43,9 @@ from flwr.server.superlink.state import StateFactory
|
|
|
42
43
|
class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
43
44
|
"""Fleet API servicer."""
|
|
44
45
|
|
|
45
|
-
def __init__(self, state_factory: StateFactory) -> None:
|
|
46
|
+
def __init__(self, state_factory: StateFactory, ffs_factory: FfsFactory) -> None:
|
|
46
47
|
self.state_factory = state_factory
|
|
48
|
+
self.ffs_factory = ffs_factory
|
|
47
49
|
|
|
48
50
|
def CreateNode(
|
|
49
51
|
self, request: CreateNodeRequest, context: grpc.ServicerContext
|
|
@@ -106,5 +108,9 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
|
|
|
106
108
|
def GetFab(
|
|
107
109
|
self, request: GetFabRequest, context: grpc.ServicerContext
|
|
108
110
|
) -> GetFabResponse:
|
|
109
|
-
"""
|
|
110
|
-
|
|
111
|
+
"""Get FAB."""
|
|
112
|
+
log(DEBUG, "DriverServicer.GetFab")
|
|
113
|
+
return message_handler.get_fab(
|
|
114
|
+
request=request,
|
|
115
|
+
ffs=self.ffs_factory.ffs(),
|
|
116
|
+
)
|
|
@@ -19,7 +19,9 @@ import time
|
|
|
19
19
|
from typing import List, Optional
|
|
20
20
|
from uuid import UUID
|
|
21
21
|
|
|
22
|
-
from flwr.common.serde import user_config_to_proto
|
|
22
|
+
from flwr.common.serde import fab_to_proto, user_config_to_proto
|
|
23
|
+
from flwr.common.typing import Fab
|
|
24
|
+
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
|
|
23
25
|
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
24
26
|
CreateNodeRequest,
|
|
25
27
|
CreateNodeResponse,
|
|
@@ -40,6 +42,7 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
|
|
|
40
42
|
Run,
|
|
41
43
|
)
|
|
42
44
|
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
|
45
|
+
from flwr.server.superlink.ffs.ffs import Ffs
|
|
43
46
|
from flwr.server.superlink.state import State
|
|
44
47
|
|
|
45
48
|
|
|
@@ -124,5 +127,17 @@ def get_run(
|
|
|
124
127
|
fab_id=run.fab_id,
|
|
125
128
|
fab_version=run.fab_version,
|
|
126
129
|
override_config=user_config_to_proto(run.override_config),
|
|
130
|
+
fab_hash=run.fab_hash,
|
|
127
131
|
)
|
|
128
132
|
)
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def get_fab(
|
|
136
|
+
request: GetFabRequest, ffs: Ffs # pylint: disable=W0613
|
|
137
|
+
) -> GetFabResponse:
|
|
138
|
+
"""Get FAB."""
|
|
139
|
+
if result := ffs.get(request.hash_str):
|
|
140
|
+
fab = Fab(request.hash_str, result[0])
|
|
141
|
+
return GetFabResponse(fab=fab_to_proto(fab))
|
|
142
|
+
|
|
143
|
+
raise ValueError(f"Found no FAB with hash: {request.hash_str}")
|
|
@@ -28,7 +28,7 @@ from typing import Callable, Dict, Optional
|
|
|
28
28
|
|
|
29
29
|
from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
|
|
30
30
|
from flwr.client.node_state import NodeState
|
|
31
|
-
from flwr.client.
|
|
31
|
+
from flwr.client.process.utils import get_load_client_app_fn
|
|
32
32
|
from flwr.common.constant import (
|
|
33
33
|
NUM_PARTITIONS_KEY,
|
|
34
34
|
PARTITION_ID_KEY,
|
|
@@ -345,7 +345,7 @@ def start_vce(
|
|
|
345
345
|
def _load() -> ClientApp:
|
|
346
346
|
|
|
347
347
|
if client_app_attr:
|
|
348
|
-
app =
|
|
348
|
+
app = get_load_client_app_fn(
|
|
349
349
|
default_app_ref=client_app_attr,
|
|
350
350
|
app_path=app_dir,
|
|
351
351
|
flwr_dir=flwr_dir,
|
|
@@ -167,7 +167,7 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
|
|
|
167
167
|
context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
|
|
168
168
|
|
|
169
169
|
# Evaluate initial parameters
|
|
170
|
-
log(INFO, "
|
|
170
|
+
log(INFO, "Starting evaluation of initial global parameters")
|
|
171
171
|
parameters = compat.parametersrecord_to_parameters(paramsrecord, keep_input=True)
|
|
172
172
|
res = context.strategy.evaluate(0, parameters=parameters)
|
|
173
173
|
if res is not None:
|
|
@@ -179,6 +179,8 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
|
|
|
179
179
|
)
|
|
180
180
|
context.history.add_loss_centralized(server_round=0, loss=res[0])
|
|
181
181
|
context.history.add_metrics_centralized(server_round=0, metrics=res[1])
|
|
182
|
+
else:
|
|
183
|
+
log(INFO, "Evaluation returned no results (`None`)")
|
|
182
184
|
|
|
183
185
|
|
|
184
186
|
def default_centralized_evaluation_workflow(_: Driver, context: Context) -> None:
|
flwr/superexec/deployment.py
CHANGED
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Deployment engine executor."""
|
|
16
16
|
|
|
17
|
+
import hashlib
|
|
17
18
|
import subprocess
|
|
18
19
|
from logging import ERROR, INFO
|
|
19
20
|
from pathlib import Path
|
|
@@ -21,12 +22,11 @@ from typing import Optional
|
|
|
21
22
|
|
|
22
23
|
from typing_extensions import override
|
|
23
24
|
|
|
24
|
-
from flwr.cli.config_utils import get_fab_metadata
|
|
25
25
|
from flwr.cli.install import install_from_fab
|
|
26
26
|
from flwr.common.grpc import create_channel
|
|
27
27
|
from flwr.common.logger import log
|
|
28
|
-
from flwr.common.serde import user_config_to_proto
|
|
29
|
-
from flwr.common.typing import UserConfig
|
|
28
|
+
from flwr.common.serde import fab_to_proto, user_config_to_proto
|
|
29
|
+
from flwr.common.typing import Fab, UserConfig
|
|
30
30
|
from flwr.proto.driver_pb2 import CreateRunRequest # pylint: disable=E0611
|
|
31
31
|
from flwr.proto.driver_pb2_grpc import DriverStub
|
|
32
32
|
from flwr.server.driver.grpc_driver import DEFAULT_SERVER_ADDRESS_DRIVER
|
|
@@ -113,8 +113,7 @@ class DeploymentEngine(Executor):
|
|
|
113
113
|
|
|
114
114
|
def _create_run(
|
|
115
115
|
self,
|
|
116
|
-
|
|
117
|
-
fab_version: str,
|
|
116
|
+
fab: Fab,
|
|
118
117
|
override_config: UserConfig,
|
|
119
118
|
) -> int:
|
|
120
119
|
if self.stub is None:
|
|
@@ -123,8 +122,7 @@ class DeploymentEngine(Executor):
|
|
|
123
122
|
assert self.stub is not None
|
|
124
123
|
|
|
125
124
|
req = CreateRunRequest(
|
|
126
|
-
|
|
127
|
-
fab_version=fab_version,
|
|
125
|
+
fab=fab_to_proto(fab),
|
|
128
126
|
override_config=user_config_to_proto(override_config),
|
|
129
127
|
)
|
|
130
128
|
res = self.stub.CreateRun(request=req)
|
|
@@ -140,11 +138,12 @@ class DeploymentEngine(Executor):
|
|
|
140
138
|
"""Start run using the Flower Deployment Engine."""
|
|
141
139
|
try:
|
|
142
140
|
# Install FAB to flwr dir
|
|
143
|
-
fab_version, fab_id = get_fab_metadata(fab_file)
|
|
144
141
|
install_from_fab(fab_file, None, True)
|
|
145
142
|
|
|
146
143
|
# Call SuperLink to create run
|
|
147
|
-
run_id: int = self._create_run(
|
|
144
|
+
run_id: int = self._create_run(
|
|
145
|
+
Fab(hashlib.sha256(fab_file).hexdigest(), fab_file), override_config
|
|
146
|
+
)
|
|
148
147
|
log(INFO, "Created run %s", str(run_id))
|
|
149
148
|
|
|
150
149
|
command = [
|