flwr-nightly 1.9.0.dev20240520__py3-none-any.whl → 1.10.0.dev20240612__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/app.py +2 -0
- flwr/cli/build.py +4 -19
- flwr/cli/config_utils.py +12 -27
- flwr/cli/install.py +196 -0
- flwr/cli/new/templates/app/pyproject.hf.toml.tpl +7 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +7 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +7 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +7 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +7 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +7 -1
- flwr/cli/run/run.py +20 -4
- flwr/cli/utils.py +14 -0
- flwr/client/__init__.py +1 -0
- flwr/client/app.py +135 -97
- flwr/client/client_app.py +1 -1
- flwr/client/grpc_rere_client/client_interceptor.py +1 -1
- flwr/client/grpc_rere_client/connection.py +6 -6
- flwr/client/mod/__init__.py +1 -1
- flwr/client/rest_client/connection.py +1 -2
- flwr/client/supernode/app.py +70 -28
- flwr/common/object_ref.py +13 -9
- flwr/common/recordset_compat.py +8 -1
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +0 -15
- flwr/proto/driver_pb2.py +20 -19
- flwr/proto/driver_pb2_grpc.py +35 -0
- flwr/proto/driver_pb2_grpc.pyi +14 -0
- flwr/proto/fleet_pb2.py +28 -33
- flwr/proto/fleet_pb2.pyi +0 -42
- flwr/proto/fleet_pb2_grpc.py +7 -6
- flwr/proto/fleet_pb2_grpc.pyi +5 -4
- flwr/proto/grpcadapter_pb2.py +32 -0
- flwr/proto/grpcadapter_pb2.pyi +43 -0
- flwr/proto/grpcadapter_pb2_grpc.py +66 -0
- flwr/proto/grpcadapter_pb2_grpc.pyi +24 -0
- flwr/proto/run_pb2.py +30 -0
- flwr/proto/run_pb2.pyi +52 -0
- flwr/proto/run_pb2_grpc.py +4 -0
- flwr/proto/run_pb2_grpc.pyi +4 -0
- flwr/server/__init__.py +0 -4
- flwr/server/app.py +190 -395
- flwr/server/run_serverapp.py +29 -5
- flwr/server/server_app.py +2 -2
- flwr/server/superlink/driver/driver_servicer.py +7 -0
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +1 -2
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +1 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +5 -3
- flwr/server/superlink/fleet/rest_rere/rest_api.py +1 -1
- {flwr_nightly-1.9.0.dev20240520.dist-info → flwr_nightly-1.10.0.dev20240612.dist-info}/METADATA +4 -3
- {flwr_nightly-1.9.0.dev20240520.dist-info → flwr_nightly-1.10.0.dev20240612.dist-info}/RECORD +53 -44
- {flwr_nightly-1.9.0.dev20240520.dist-info → flwr_nightly-1.10.0.dev20240612.dist-info}/entry_points.txt +0 -2
- {flwr_nightly-1.9.0.dev20240520.dist-info → flwr_nightly-1.10.0.dev20240612.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.9.0.dev20240520.dist-info → flwr_nightly-1.10.0.dev20240612.dist-info}/WHEEL +0 -0
flwr/server/app.py
CHANGED
|
@@ -15,17 +15,17 @@
|
|
|
15
15
|
"""Flower server app."""
|
|
16
16
|
|
|
17
17
|
import argparse
|
|
18
|
-
import asyncio
|
|
19
18
|
import csv
|
|
20
19
|
import importlib.util
|
|
21
20
|
import sys
|
|
22
21
|
import threading
|
|
23
|
-
from logging import
|
|
22
|
+
from logging import INFO, WARN
|
|
24
23
|
from os.path import isfile
|
|
25
24
|
from pathlib import Path
|
|
26
|
-
from typing import
|
|
25
|
+
from typing import Optional, Sequence, Set, Tuple
|
|
27
26
|
|
|
28
27
|
import grpc
|
|
28
|
+
from cryptography.exceptions import UnsupportedAlgorithm
|
|
29
29
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
30
30
|
from cryptography.hazmat.primitives.serialization import (
|
|
31
31
|
load_ssh_private_key,
|
|
@@ -38,14 +38,12 @@ from flwr.common.constant import (
|
|
|
38
38
|
MISSING_EXTRA_REST,
|
|
39
39
|
TRANSPORT_TYPE_GRPC_RERE,
|
|
40
40
|
TRANSPORT_TYPE_REST,
|
|
41
|
-
TRANSPORT_TYPE_VCE,
|
|
42
41
|
)
|
|
43
42
|
from flwr.common.exit_handlers import register_exit_handlers
|
|
44
|
-
from flwr.common.logger import log
|
|
43
|
+
from flwr.common.logger import log
|
|
45
44
|
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
|
46
45
|
private_key_to_bytes,
|
|
47
46
|
public_key_to_bytes,
|
|
48
|
-
ssh_types_to_elliptic_curve,
|
|
49
47
|
)
|
|
50
48
|
from flwr.proto.fleet_pb2_grpc import ( # pylint: disable=E0611
|
|
51
49
|
add_FleetServicer_to_server,
|
|
@@ -63,7 +61,6 @@ from .superlink.fleet.grpc_bidi.grpc_server import (
|
|
|
63
61
|
)
|
|
64
62
|
from .superlink.fleet.grpc_rere.fleet_servicer import FleetServicer
|
|
65
63
|
from .superlink.fleet.grpc_rere.server_interceptor import AuthenticateServerInterceptor
|
|
66
|
-
from .superlink.fleet.vce import start_vce
|
|
67
64
|
from .superlink.state import StateFactory
|
|
68
65
|
|
|
69
66
|
ADDRESS_DRIVER_API = "0.0.0.0:9091"
|
|
@@ -193,120 +190,6 @@ def start_server( # pylint: disable=too-many-arguments,too-many-locals
|
|
|
193
190
|
return hist
|
|
194
191
|
|
|
195
192
|
|
|
196
|
-
def run_driver_api() -> None:
|
|
197
|
-
"""Run Flower server (Driver API)."""
|
|
198
|
-
log(INFO, "Starting Flower server (Driver API)")
|
|
199
|
-
# Running `flower-driver-api` is deprecated
|
|
200
|
-
warn_deprecated_feature("flower-driver-api")
|
|
201
|
-
log(WARN, "Use `flower-superlink` instead")
|
|
202
|
-
event(EventType.RUN_DRIVER_API_ENTER)
|
|
203
|
-
args = _parse_args_run_driver_api().parse_args()
|
|
204
|
-
|
|
205
|
-
# Parse IP address
|
|
206
|
-
parsed_address = parse_address(args.driver_api_address)
|
|
207
|
-
if not parsed_address:
|
|
208
|
-
sys.exit(f"Driver IP address ({args.driver_api_address}) cannot be parsed.")
|
|
209
|
-
host, port, is_v6 = parsed_address
|
|
210
|
-
address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}"
|
|
211
|
-
|
|
212
|
-
# Obtain certificates
|
|
213
|
-
certificates = _try_obtain_certificates(args)
|
|
214
|
-
|
|
215
|
-
# Initialize StateFactory
|
|
216
|
-
state_factory = StateFactory(args.database)
|
|
217
|
-
|
|
218
|
-
# Start server
|
|
219
|
-
grpc_server: grpc.Server = run_driver_api_grpc(
|
|
220
|
-
address=address,
|
|
221
|
-
state_factory=state_factory,
|
|
222
|
-
certificates=certificates,
|
|
223
|
-
)
|
|
224
|
-
|
|
225
|
-
# Graceful shutdown
|
|
226
|
-
register_exit_handlers(
|
|
227
|
-
event_type=EventType.RUN_DRIVER_API_LEAVE,
|
|
228
|
-
grpc_servers=[grpc_server],
|
|
229
|
-
bckg_threads=[],
|
|
230
|
-
)
|
|
231
|
-
|
|
232
|
-
# Block
|
|
233
|
-
grpc_server.wait_for_termination()
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
def run_fleet_api() -> None:
|
|
237
|
-
"""Run Flower server (Fleet API)."""
|
|
238
|
-
log(INFO, "Starting Flower server (Fleet API)")
|
|
239
|
-
# Running `flower-fleet-api` is deprecated
|
|
240
|
-
warn_deprecated_feature("flower-fleet-api")
|
|
241
|
-
log(WARN, "Use `flower-superlink` instead")
|
|
242
|
-
event(EventType.RUN_FLEET_API_ENTER)
|
|
243
|
-
args = _parse_args_run_fleet_api().parse_args()
|
|
244
|
-
|
|
245
|
-
# Obtain certificates
|
|
246
|
-
certificates = _try_obtain_certificates(args)
|
|
247
|
-
|
|
248
|
-
# Initialize StateFactory
|
|
249
|
-
state_factory = StateFactory(args.database)
|
|
250
|
-
|
|
251
|
-
grpc_servers = []
|
|
252
|
-
bckg_threads = []
|
|
253
|
-
|
|
254
|
-
# Start Fleet API
|
|
255
|
-
if args.fleet_api_type == TRANSPORT_TYPE_REST:
|
|
256
|
-
if (
|
|
257
|
-
importlib.util.find_spec("requests")
|
|
258
|
-
and importlib.util.find_spec("starlette")
|
|
259
|
-
and importlib.util.find_spec("uvicorn")
|
|
260
|
-
) is None:
|
|
261
|
-
sys.exit(MISSING_EXTRA_REST)
|
|
262
|
-
address_arg = args.rest_fleet_api_address
|
|
263
|
-
parsed_address = parse_address(address_arg)
|
|
264
|
-
if not parsed_address:
|
|
265
|
-
sys.exit(f"Fleet IP address ({address_arg}) cannot be parsed.")
|
|
266
|
-
host, port, _ = parsed_address
|
|
267
|
-
fleet_thread = threading.Thread(
|
|
268
|
-
target=_run_fleet_api_rest,
|
|
269
|
-
args=(
|
|
270
|
-
host,
|
|
271
|
-
port,
|
|
272
|
-
args.ssl_keyfile,
|
|
273
|
-
args.ssl_certfile,
|
|
274
|
-
state_factory,
|
|
275
|
-
args.rest_fleet_api_workers,
|
|
276
|
-
),
|
|
277
|
-
)
|
|
278
|
-
fleet_thread.start()
|
|
279
|
-
bckg_threads.append(fleet_thread)
|
|
280
|
-
elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE:
|
|
281
|
-
address_arg = args.grpc_rere_fleet_api_address
|
|
282
|
-
parsed_address = parse_address(address_arg)
|
|
283
|
-
if not parsed_address:
|
|
284
|
-
sys.exit(f"Fleet IP address ({address_arg}) cannot be parsed.")
|
|
285
|
-
host, port, is_v6 = parsed_address
|
|
286
|
-
address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}"
|
|
287
|
-
fleet_server = _run_fleet_api_grpc_rere(
|
|
288
|
-
address=address,
|
|
289
|
-
state_factory=state_factory,
|
|
290
|
-
certificates=certificates,
|
|
291
|
-
)
|
|
292
|
-
grpc_servers.append(fleet_server)
|
|
293
|
-
else:
|
|
294
|
-
raise ValueError(f"Unknown fleet_api_type: {args.fleet_api_type}")
|
|
295
|
-
|
|
296
|
-
# Graceful shutdown
|
|
297
|
-
register_exit_handlers(
|
|
298
|
-
event_type=EventType.RUN_FLEET_API_LEAVE,
|
|
299
|
-
grpc_servers=grpc_servers,
|
|
300
|
-
bckg_threads=bckg_threads,
|
|
301
|
-
)
|
|
302
|
-
|
|
303
|
-
# Block
|
|
304
|
-
if len(grpc_servers) > 0:
|
|
305
|
-
grpc_servers[0].wait_for_termination()
|
|
306
|
-
elif len(bckg_threads) > 0:
|
|
307
|
-
bckg_threads[0].join()
|
|
308
|
-
|
|
309
|
-
|
|
310
193
|
# pylint: disable=too-many-branches, too-many-locals, too-many-statements
|
|
311
194
|
def run_superlink() -> None:
|
|
312
195
|
"""Run Flower SuperLink (Driver API and Fleet API)."""
|
|
@@ -317,11 +200,15 @@ def run_superlink() -> None:
|
|
|
317
200
|
args = _parse_args_run_superlink().parse_args()
|
|
318
201
|
|
|
319
202
|
# Parse IP address
|
|
320
|
-
|
|
321
|
-
if not
|
|
203
|
+
parsed_driver_address = parse_address(args.driver_api_address)
|
|
204
|
+
if not parsed_driver_address:
|
|
322
205
|
sys.exit(f"Driver IP address ({args.driver_api_address}) cannot be parsed.")
|
|
323
|
-
|
|
324
|
-
|
|
206
|
+
driver_host, driver_port, driver_is_v6 = parsed_driver_address
|
|
207
|
+
driver_address = (
|
|
208
|
+
f"[{driver_host}]:{driver_port}"
|
|
209
|
+
if driver_is_v6
|
|
210
|
+
else f"{driver_host}:{driver_port}"
|
|
211
|
+
)
|
|
325
212
|
|
|
326
213
|
# Obtain certificates
|
|
327
214
|
certificates = _try_obtain_certificates(args)
|
|
@@ -331,13 +218,38 @@ def run_superlink() -> None:
|
|
|
331
218
|
|
|
332
219
|
# Start Driver API
|
|
333
220
|
driver_server: grpc.Server = run_driver_api_grpc(
|
|
334
|
-
address=
|
|
221
|
+
address=driver_address,
|
|
335
222
|
state_factory=state_factory,
|
|
336
223
|
certificates=certificates,
|
|
337
224
|
)
|
|
338
225
|
|
|
339
226
|
grpc_servers = [driver_server]
|
|
340
227
|
bckg_threads = []
|
|
228
|
+
if not args.fleet_api_address:
|
|
229
|
+
args.fleet_api_address = (
|
|
230
|
+
ADDRESS_FLEET_API_GRPC_RERE
|
|
231
|
+
if args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE
|
|
232
|
+
else ADDRESS_FLEET_API_REST
|
|
233
|
+
)
|
|
234
|
+
parsed_fleet_address = parse_address(args.fleet_api_address)
|
|
235
|
+
if not parsed_fleet_address:
|
|
236
|
+
sys.exit(f"Fleet IP address ({args.fleet_api_address}) cannot be parsed.")
|
|
237
|
+
fleet_host, fleet_port, fleet_is_v6 = parsed_fleet_address
|
|
238
|
+
fleet_address = (
|
|
239
|
+
f"[{fleet_host}]:{fleet_port}" if fleet_is_v6 else f"{fleet_host}:{fleet_port}"
|
|
240
|
+
)
|
|
241
|
+
|
|
242
|
+
num_workers = args.fleet_api_num_workers
|
|
243
|
+
if num_workers != 1:
|
|
244
|
+
log(
|
|
245
|
+
WARN,
|
|
246
|
+
"The Fleet API currently supports only 1 worker. "
|
|
247
|
+
"You have specified %d workers. "
|
|
248
|
+
"Support for multiple workers will be added in future releases. "
|
|
249
|
+
"Proceeding with a single worker.",
|
|
250
|
+
args.fleet_api_num_workers,
|
|
251
|
+
)
|
|
252
|
+
num_workers = 1
|
|
341
253
|
|
|
342
254
|
# Start Fleet API
|
|
343
255
|
if args.fleet_api_type == TRANSPORT_TYPE_REST:
|
|
@@ -347,32 +259,25 @@ def run_superlink() -> None:
|
|
|
347
259
|
and importlib.util.find_spec("uvicorn")
|
|
348
260
|
) is None:
|
|
349
261
|
sys.exit(MISSING_EXTRA_REST)
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
262
|
+
|
|
263
|
+
_, ssl_certfile, ssl_keyfile = (
|
|
264
|
+
certificates if certificates is not None else (None, None, None)
|
|
265
|
+
)
|
|
266
|
+
|
|
355
267
|
fleet_thread = threading.Thread(
|
|
356
268
|
target=_run_fleet_api_rest,
|
|
357
269
|
args=(
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
270
|
+
fleet_host,
|
|
271
|
+
fleet_port,
|
|
272
|
+
ssl_keyfile,
|
|
273
|
+
ssl_certfile,
|
|
362
274
|
state_factory,
|
|
363
|
-
|
|
275
|
+
num_workers,
|
|
364
276
|
),
|
|
365
277
|
)
|
|
366
278
|
fleet_thread.start()
|
|
367
279
|
bckg_threads.append(fleet_thread)
|
|
368
280
|
elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE:
|
|
369
|
-
address_arg = args.grpc_rere_fleet_api_address
|
|
370
|
-
parsed_address = parse_address(address_arg)
|
|
371
|
-
if not parsed_address:
|
|
372
|
-
sys.exit(f"Fleet IP address ({address_arg}) cannot be parsed.")
|
|
373
|
-
host, port, is_v6 = parsed_address
|
|
374
|
-
address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}"
|
|
375
|
-
|
|
376
281
|
maybe_keys = _try_setup_client_authentication(args, certificates)
|
|
377
282
|
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None
|
|
378
283
|
if maybe_keys is not None:
|
|
@@ -395,23 +300,12 @@ def run_superlink() -> None:
|
|
|
395
300
|
interceptors = [AuthenticateServerInterceptor(state)]
|
|
396
301
|
|
|
397
302
|
fleet_server = _run_fleet_api_grpc_rere(
|
|
398
|
-
address=
|
|
303
|
+
address=fleet_address,
|
|
399
304
|
state_factory=state_factory,
|
|
400
305
|
certificates=certificates,
|
|
401
306
|
interceptors=interceptors,
|
|
402
307
|
)
|
|
403
308
|
grpc_servers.append(fleet_server)
|
|
404
|
-
elif args.fleet_api_type == TRANSPORT_TYPE_VCE:
|
|
405
|
-
f_stop = asyncio.Event() # Does nothing
|
|
406
|
-
_run_fleet_api_vce(
|
|
407
|
-
num_supernodes=args.num_supernodes,
|
|
408
|
-
client_app_attr=args.client_app,
|
|
409
|
-
backend_name=args.backend,
|
|
410
|
-
backend_config_json_stream=args.backend_config,
|
|
411
|
-
app_dir=args.app_dir,
|
|
412
|
-
state_factory=state_factory,
|
|
413
|
-
f_stop=f_stop,
|
|
414
|
-
)
|
|
415
309
|
else:
|
|
416
310
|
raise ValueError(f"Unknown fleet_api_type: {args.fleet_api_type}")
|
|
417
311
|
|
|
@@ -435,44 +329,69 @@ def _try_setup_client_authentication(
|
|
|
435
329
|
args: argparse.Namespace,
|
|
436
330
|
certificates: Optional[Tuple[bytes, bytes, bytes]],
|
|
437
331
|
) -> Optional[Tuple[Set[bytes], ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]]:
|
|
438
|
-
if
|
|
332
|
+
if (
|
|
333
|
+
not args.auth_list_public_keys
|
|
334
|
+
and not args.auth_superlink_private_key
|
|
335
|
+
and not args.auth_superlink_public_key
|
|
336
|
+
):
|
|
439
337
|
return None
|
|
440
338
|
|
|
339
|
+
if (
|
|
340
|
+
not args.auth_list_public_keys
|
|
341
|
+
or not args.auth_superlink_private_key
|
|
342
|
+
or not args.auth_superlink_public_key
|
|
343
|
+
):
|
|
344
|
+
sys.exit(
|
|
345
|
+
"Authentication requires providing file paths for "
|
|
346
|
+
"'--auth-list-public-keys', '--auth-superlink-private-key' and "
|
|
347
|
+
"'--auth-superlink-public-key'. Provide all three to enable authentication."
|
|
348
|
+
)
|
|
349
|
+
|
|
441
350
|
if certificates is None:
|
|
442
351
|
sys.exit(
|
|
443
|
-
"
|
|
444
|
-
"Please provide certificate paths
|
|
445
|
-
"
|
|
352
|
+
"Authentication requires secure connections. "
|
|
353
|
+
"Please provide certificate paths to `--ssl-certfile`, "
|
|
354
|
+
"`--ssl-keyfile`, and `—-ssl-ca-certfile` and try again."
|
|
446
355
|
)
|
|
447
356
|
|
|
448
|
-
client_keys_file_path = Path(args.
|
|
357
|
+
client_keys_file_path = Path(args.auth_list_public_keys)
|
|
449
358
|
if not client_keys_file_path.exists():
|
|
450
359
|
sys.exit(
|
|
451
|
-
"The provided path to the
|
|
360
|
+
"The provided path to the known public keys CSV file does not exist: "
|
|
452
361
|
f"{client_keys_file_path}. "
|
|
453
|
-
"Please provide the CSV file path containing known
|
|
454
|
-
"to '--
|
|
362
|
+
"Please provide the CSV file path containing known public keys "
|
|
363
|
+
"to '--auth-list-public-keys'."
|
|
455
364
|
)
|
|
456
365
|
|
|
457
366
|
client_public_keys: Set[bytes] = set()
|
|
458
|
-
ssh_private_key = load_ssh_private_key(
|
|
459
|
-
Path(args.require_client_authentication[1]).read_bytes(),
|
|
460
|
-
None,
|
|
461
|
-
)
|
|
462
|
-
ssh_public_key = load_ssh_public_key(
|
|
463
|
-
Path(args.require_client_authentication[2]).read_bytes()
|
|
464
|
-
)
|
|
465
367
|
|
|
466
368
|
try:
|
|
467
|
-
|
|
468
|
-
|
|
369
|
+
ssh_private_key = load_ssh_private_key(
|
|
370
|
+
Path(args.auth_superlink_private_key).read_bytes(),
|
|
371
|
+
None,
|
|
372
|
+
)
|
|
373
|
+
if not isinstance(ssh_private_key, ec.EllipticCurvePrivateKey):
|
|
374
|
+
raise ValueError()
|
|
375
|
+
except (ValueError, UnsupportedAlgorithm):
|
|
376
|
+
sys.exit(
|
|
377
|
+
"Error: Unable to parse the private key file in "
|
|
378
|
+
"'--auth-superlink-private-key'. Authentication requires elliptic "
|
|
379
|
+
"curve private and public key pair. Please ensure that the file "
|
|
380
|
+
"path points to a valid private key file and try again."
|
|
381
|
+
)
|
|
382
|
+
|
|
383
|
+
try:
|
|
384
|
+
ssh_public_key = load_ssh_public_key(
|
|
385
|
+
Path(args.auth_superlink_public_key).read_bytes()
|
|
469
386
|
)
|
|
470
|
-
|
|
387
|
+
if not isinstance(ssh_public_key, ec.EllipticCurvePublicKey):
|
|
388
|
+
raise ValueError()
|
|
389
|
+
except (ValueError, UnsupportedAlgorithm):
|
|
471
390
|
sys.exit(
|
|
472
|
-
"
|
|
473
|
-
"key
|
|
474
|
-
"private key pair. Please
|
|
475
|
-
"
|
|
391
|
+
"Error: Unable to parse the public key file in "
|
|
392
|
+
"'--auth-superlink-public-key'. Authentication requires elliptic "
|
|
393
|
+
"curve private and public key pair. Please ensure that the file "
|
|
394
|
+
"path points to a valid public key file and try again."
|
|
476
395
|
)
|
|
477
396
|
|
|
478
397
|
with open(client_keys_file_path, newline="", encoding="utf-8") as csvfile:
|
|
@@ -484,14 +403,14 @@ def _try_setup_client_authentication(
|
|
|
484
403
|
client_public_keys.add(public_key_to_bytes(public_key))
|
|
485
404
|
else:
|
|
486
405
|
sys.exit(
|
|
487
|
-
"Error: Unable to parse the public keys in the
|
|
488
|
-
"file. Please ensure that the
|
|
489
|
-
"SSH public keys and try again."
|
|
406
|
+
"Error: Unable to parse the public keys in the CSV "
|
|
407
|
+
"file. Please ensure that the CSV file path points to a valid "
|
|
408
|
+
"known SSH public keys files and try again."
|
|
490
409
|
)
|
|
491
410
|
return (
|
|
492
411
|
client_public_keys,
|
|
493
|
-
|
|
494
|
-
|
|
412
|
+
ssh_private_key,
|
|
413
|
+
ssh_public_key,
|
|
495
414
|
)
|
|
496
415
|
|
|
497
416
|
|
|
@@ -501,21 +420,52 @@ def _try_obtain_certificates(
|
|
|
501
420
|
# Obtain certificates
|
|
502
421
|
if args.insecure:
|
|
503
422
|
log(WARN, "Option `--insecure` was set. Starting insecure HTTP server.")
|
|
504
|
-
|
|
423
|
+
return None
|
|
505
424
|
# Check if certificates are provided
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
425
|
+
if args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE:
|
|
426
|
+
if args.ssl_certfile and args.ssl_keyfile and args.ssl_ca_certfile:
|
|
427
|
+
if not isfile(args.ssl_ca_certfile):
|
|
428
|
+
sys.exit("Path argument `--ssl-ca-certfile` does not point to a file.")
|
|
429
|
+
if not isfile(args.ssl_certfile):
|
|
430
|
+
sys.exit("Path argument `--ssl-certfile` does not point to a file.")
|
|
431
|
+
if not isfile(args.ssl_keyfile):
|
|
432
|
+
sys.exit("Path argument `--ssl-keyfile` does not point to a file.")
|
|
433
|
+
certificates = (
|
|
434
|
+
Path(args.ssl_ca_certfile).read_bytes(), # CA certificate
|
|
435
|
+
Path(args.ssl_certfile).read_bytes(), # server certificate
|
|
436
|
+
Path(args.ssl_keyfile).read_bytes(), # server private key
|
|
437
|
+
)
|
|
438
|
+
return certificates
|
|
439
|
+
if args.ssl_certfile or args.ssl_keyfile or args.ssl_ca_certfile:
|
|
440
|
+
sys.exit(
|
|
441
|
+
"You need to provide valid file paths to `--ssl-certfile`, "
|
|
442
|
+
"`--ssl-keyfile`, and `—-ssl-ca-certfile` to create a secure "
|
|
443
|
+
"connection in Fleet API server (gRPC-rere)."
|
|
444
|
+
)
|
|
445
|
+
if args.fleet_api_type == TRANSPORT_TYPE_REST:
|
|
446
|
+
if args.ssl_certfile and args.ssl_keyfile:
|
|
447
|
+
if not isfile(args.ssl_certfile):
|
|
448
|
+
sys.exit("Path argument `--ssl-certfile` does not point to a file.")
|
|
449
|
+
if not isfile(args.ssl_keyfile):
|
|
450
|
+
sys.exit("Path argument `--ssl-keyfile` does not point to a file.")
|
|
451
|
+
certificates = (
|
|
452
|
+
b"",
|
|
453
|
+
Path(args.ssl_certfile).read_bytes(), # server certificate
|
|
454
|
+
Path(args.ssl_keyfile).read_bytes(), # server private key
|
|
455
|
+
)
|
|
456
|
+
return certificates
|
|
457
|
+
if args.ssl_certfile or args.ssl_keyfile:
|
|
458
|
+
sys.exit(
|
|
459
|
+
"You need to provide valid file paths to `--ssl-certfile` "
|
|
460
|
+
"and `--ssl-keyfile` to create a secure connection "
|
|
461
|
+
"in Fleet API server (REST, experimental)."
|
|
462
|
+
)
|
|
463
|
+
sys.exit(
|
|
464
|
+
"Certificates are required unless running in insecure mode. "
|
|
465
|
+
"Please provide certificate paths to `--ssl-certfile`, "
|
|
466
|
+
"`--ssl-keyfile`, and `—-ssl-ca-certfile` or run the server "
|
|
467
|
+
"in insecure mode using '--insecure' if you understand the risks."
|
|
468
|
+
)
|
|
519
469
|
|
|
520
470
|
|
|
521
471
|
def _run_fleet_api_grpc_rere(
|
|
@@ -544,29 +494,6 @@ def _run_fleet_api_grpc_rere(
|
|
|
544
494
|
return fleet_grpc_server
|
|
545
495
|
|
|
546
496
|
|
|
547
|
-
# pylint: disable=too-many-arguments
|
|
548
|
-
def _run_fleet_api_vce(
|
|
549
|
-
num_supernodes: int,
|
|
550
|
-
client_app_attr: str,
|
|
551
|
-
backend_name: str,
|
|
552
|
-
backend_config_json_stream: str,
|
|
553
|
-
app_dir: str,
|
|
554
|
-
state_factory: StateFactory,
|
|
555
|
-
f_stop: asyncio.Event,
|
|
556
|
-
) -> None:
|
|
557
|
-
log(INFO, "Flower VCE: Starting Fleet API (VirtualClientEngine)")
|
|
558
|
-
|
|
559
|
-
start_vce(
|
|
560
|
-
num_supernodes=num_supernodes,
|
|
561
|
-
client_app_attr=client_app_attr,
|
|
562
|
-
backend_name=backend_name,
|
|
563
|
-
backend_config_json_stream=backend_config_json_stream,
|
|
564
|
-
state_factory=state_factory,
|
|
565
|
-
app_dir=app_dir,
|
|
566
|
-
f_stop=f_stop,
|
|
567
|
-
)
|
|
568
|
-
|
|
569
|
-
|
|
570
497
|
# pylint: disable=import-outside-toplevel,too-many-arguments
|
|
571
498
|
def _run_fleet_api_rest(
|
|
572
499
|
host: str,
|
|
@@ -574,7 +501,7 @@ def _run_fleet_api_rest(
|
|
|
574
501
|
ssl_keyfile: Optional[str],
|
|
575
502
|
ssl_certfile: Optional[str],
|
|
576
503
|
state_factory: StateFactory,
|
|
577
|
-
|
|
504
|
+
num_workers: int,
|
|
578
505
|
) -> None:
|
|
579
506
|
"""Run Driver API (REST-based)."""
|
|
580
507
|
try:
|
|
@@ -583,25 +510,12 @@ def _run_fleet_api_rest(
|
|
|
583
510
|
from flwr.server.superlink.fleet.rest_rere.rest_api import app as fast_api_app
|
|
584
511
|
except ModuleNotFoundError:
|
|
585
512
|
sys.exit(MISSING_EXTRA_REST)
|
|
586
|
-
|
|
587
|
-
raise ValueError(
|
|
588
|
-
f"The supported number of workers for the Fleet API (REST server) is "
|
|
589
|
-
f"1. Instead given {workers}. The functionality of >1 workers will be "
|
|
590
|
-
f"added in the future releases."
|
|
591
|
-
)
|
|
513
|
+
|
|
592
514
|
log(INFO, "Starting Flower REST server")
|
|
593
515
|
|
|
594
516
|
# See: https://www.starlette.io/applications/#accessing-the-app-instance
|
|
595
517
|
fast_api_app.state.STATE_FACTORY = state_factory
|
|
596
518
|
|
|
597
|
-
validation_exceptions = _validate_ssl_files(
|
|
598
|
-
ssl_certfile=ssl_certfile, ssl_keyfile=ssl_keyfile
|
|
599
|
-
)
|
|
600
|
-
if any(validation_exceptions):
|
|
601
|
-
# Starting with 3.11 we can use ExceptionGroup but for now
|
|
602
|
-
# this seems to be the reasonable approach.
|
|
603
|
-
raise ValueError(validation_exceptions)
|
|
604
|
-
|
|
605
519
|
uvicorn.run(
|
|
606
520
|
app="flwr.server.superlink.fleet.rest_rere.rest_api:app",
|
|
607
521
|
port=port,
|
|
@@ -610,70 +524,10 @@ def _run_fleet_api_rest(
|
|
|
610
524
|
access_log=True,
|
|
611
525
|
ssl_keyfile=ssl_keyfile,
|
|
612
526
|
ssl_certfile=ssl_certfile,
|
|
613
|
-
workers=
|
|
527
|
+
workers=num_workers,
|
|
614
528
|
)
|
|
615
529
|
|
|
616
530
|
|
|
617
|
-
def _validate_ssl_files(
|
|
618
|
-
ssl_keyfile: Optional[str], ssl_certfile: Optional[str]
|
|
619
|
-
) -> List[ValueError]:
|
|
620
|
-
validation_exceptions = []
|
|
621
|
-
|
|
622
|
-
if ssl_keyfile is not None and not isfile(ssl_keyfile):
|
|
623
|
-
msg = "Path argument `--ssl-keyfile` does not point to a file."
|
|
624
|
-
log(ERROR, msg)
|
|
625
|
-
validation_exceptions.append(ValueError(msg))
|
|
626
|
-
|
|
627
|
-
if ssl_certfile is not None and not isfile(ssl_certfile):
|
|
628
|
-
msg = "Path argument `--ssl-certfile` does not point to a file."
|
|
629
|
-
log(ERROR, msg)
|
|
630
|
-
validation_exceptions.append(ValueError(msg))
|
|
631
|
-
|
|
632
|
-
if not bool(ssl_keyfile) == bool(ssl_certfile):
|
|
633
|
-
msg = (
|
|
634
|
-
"When setting one of `--ssl-keyfile` and "
|
|
635
|
-
"`--ssl-certfile`, both have to be used."
|
|
636
|
-
)
|
|
637
|
-
log(ERROR, msg)
|
|
638
|
-
validation_exceptions.append(ValueError(msg))
|
|
639
|
-
|
|
640
|
-
return validation_exceptions
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
def _parse_args_run_driver_api() -> argparse.ArgumentParser:
|
|
644
|
-
"""Parse command line arguments for Driver API."""
|
|
645
|
-
parser = argparse.ArgumentParser(
|
|
646
|
-
description="Start a Flower Driver API server. "
|
|
647
|
-
"This server will be responsible for "
|
|
648
|
-
"receiving TaskIns from the Driver script and "
|
|
649
|
-
"sending them to the Fleet API. Once the client nodes "
|
|
650
|
-
"are done, they will send the TaskRes back to this Driver API server (through"
|
|
651
|
-
" the Fleet API) which will then send them back to the Driver script.",
|
|
652
|
-
)
|
|
653
|
-
|
|
654
|
-
_add_args_common(parser=parser)
|
|
655
|
-
_add_args_driver_api(parser=parser)
|
|
656
|
-
|
|
657
|
-
return parser
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
def _parse_args_run_fleet_api() -> argparse.ArgumentParser:
|
|
661
|
-
"""Parse command line arguments for Fleet API."""
|
|
662
|
-
parser = argparse.ArgumentParser(
|
|
663
|
-
description="Start a Flower Fleet API server."
|
|
664
|
-
"This server will be responsible for "
|
|
665
|
-
"sending TaskIns (received from the Driver API) to the client nodes "
|
|
666
|
-
"and of receiving TaskRes sent back from those same client nodes once "
|
|
667
|
-
"they are done. Then, this Fleet API server can send those "
|
|
668
|
-
"TaskRes back to the Driver API.",
|
|
669
|
-
)
|
|
670
|
-
|
|
671
|
-
_add_args_common(parser=parser)
|
|
672
|
-
_add_args_fleet_api(parser=parser)
|
|
673
|
-
|
|
674
|
-
return parser
|
|
675
|
-
|
|
676
|
-
|
|
677
531
|
def _parse_args_run_superlink() -> argparse.ArgumentParser:
|
|
678
532
|
"""Parse command line arguments for both Driver API and Fleet API."""
|
|
679
533
|
parser = argparse.ArgumentParser(
|
|
@@ -696,13 +550,23 @@ def _add_args_common(parser: argparse.ArgumentParser) -> None:
|
|
|
696
550
|
"Use this flag only if you understand the risks.",
|
|
697
551
|
)
|
|
698
552
|
parser.add_argument(
|
|
699
|
-
"--
|
|
700
|
-
|
|
701
|
-
|
|
553
|
+
"--ssl-certfile",
|
|
554
|
+
help="Fleet API server SSL certificate file (as a path str) "
|
|
555
|
+
"to create a secure connection.",
|
|
556
|
+
type=str,
|
|
557
|
+
default=None,
|
|
558
|
+
)
|
|
559
|
+
parser.add_argument(
|
|
560
|
+
"--ssl-keyfile",
|
|
561
|
+
help="Fleet API server SSL private key file (as a path str) "
|
|
562
|
+
"to create a secure connection.",
|
|
563
|
+
type=str,
|
|
564
|
+
)
|
|
565
|
+
parser.add_argument(
|
|
566
|
+
"--ssl-ca-certfile",
|
|
567
|
+
help="Fleet API server SSL CA certificate file (as a path str) "
|
|
568
|
+
"to create a secure connection.",
|
|
702
569
|
type=str,
|
|
703
|
-
help="Paths to the CA certificate, server certificate, and server private "
|
|
704
|
-
"key, in that order. Note: The server can only be started without "
|
|
705
|
-
"certificates by enabling the `--insecure` flag.",
|
|
706
570
|
)
|
|
707
571
|
parser.add_argument(
|
|
708
572
|
"--database",
|
|
@@ -714,116 +578,47 @@ def _add_args_common(parser: argparse.ArgumentParser) -> None:
|
|
|
714
578
|
default=DATABASE,
|
|
715
579
|
)
|
|
716
580
|
parser.add_argument(
|
|
717
|
-
"--
|
|
718
|
-
nargs=3,
|
|
719
|
-
metavar=("CLIENT_KEYS", "SERVER_PRIVATE_KEY", "SERVER_PUBLIC_KEY"),
|
|
581
|
+
"--auth-list-public-keys",
|
|
720
582
|
type=str,
|
|
721
|
-
help="
|
|
722
|
-
"
|
|
723
|
-
|
|
583
|
+
help="A CSV file (as a path str) containing a list of known public "
|
|
584
|
+
"keys to enable authentication.",
|
|
585
|
+
)
|
|
586
|
+
parser.add_argument(
|
|
587
|
+
"--auth-superlink-private-key",
|
|
588
|
+
type=str,
|
|
589
|
+
help="The SuperLink's private key (as a path str) to enable authentication.",
|
|
590
|
+
)
|
|
591
|
+
parser.add_argument(
|
|
592
|
+
"--auth-superlink-public-key",
|
|
593
|
+
type=str,
|
|
594
|
+
help="The SuperLink's public key (as a path str) to enable authentication.",
|
|
724
595
|
)
|
|
725
596
|
|
|
726
597
|
|
|
727
598
|
def _add_args_driver_api(parser: argparse.ArgumentParser) -> None:
|
|
728
599
|
parser.add_argument(
|
|
729
600
|
"--driver-api-address",
|
|
730
|
-
help="Driver API (gRPC) server address (IPv4, IPv6, or a domain name)",
|
|
601
|
+
help="Driver API (gRPC) server address (IPv4, IPv6, or a domain name).",
|
|
731
602
|
default=ADDRESS_DRIVER_API,
|
|
732
603
|
)
|
|
733
604
|
|
|
734
605
|
|
|
735
606
|
def _add_args_fleet_api(parser: argparse.ArgumentParser) -> None:
|
|
736
607
|
# Fleet API transport layer type
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
"--grpc-rere",
|
|
740
|
-
action="store_const",
|
|
741
|
-
dest="fleet_api_type",
|
|
742
|
-
const=TRANSPORT_TYPE_GRPC_RERE,
|
|
608
|
+
parser.add_argument(
|
|
609
|
+
"--fleet-api-type",
|
|
743
610
|
default=TRANSPORT_TYPE_GRPC_RERE,
|
|
744
|
-
help="Start a Fleet API server (gRPC-rere)",
|
|
745
|
-
)
|
|
746
|
-
ex_group.add_argument(
|
|
747
|
-
"--rest",
|
|
748
|
-
action="store_const",
|
|
749
|
-
dest="fleet_api_type",
|
|
750
|
-
const=TRANSPORT_TYPE_REST,
|
|
751
|
-
help="Start a Fleet API server (REST, experimental)",
|
|
752
|
-
)
|
|
753
|
-
|
|
754
|
-
ex_group.add_argument(
|
|
755
|
-
"--vce",
|
|
756
|
-
action="store_const",
|
|
757
|
-
dest="fleet_api_type",
|
|
758
|
-
const=TRANSPORT_TYPE_VCE,
|
|
759
|
-
help="Start a Fleet API server (VirtualClientEngine)",
|
|
760
|
-
)
|
|
761
|
-
|
|
762
|
-
# Fleet API gRPC-rere options
|
|
763
|
-
grpc_rere_group = parser.add_argument_group(
|
|
764
|
-
"Fleet API (gRPC-rere) server options", ""
|
|
765
|
-
)
|
|
766
|
-
grpc_rere_group.add_argument(
|
|
767
|
-
"--grpc-rere-fleet-api-address",
|
|
768
|
-
help="Fleet API (gRPC-rere) server address (IPv4, IPv6, or a domain name)",
|
|
769
|
-
default=ADDRESS_FLEET_API_GRPC_RERE,
|
|
770
|
-
)
|
|
771
|
-
|
|
772
|
-
# Fleet API REST options
|
|
773
|
-
rest_group = parser.add_argument_group("Fleet API (REST) server options", "")
|
|
774
|
-
rest_group.add_argument(
|
|
775
|
-
"--rest-fleet-api-address",
|
|
776
|
-
help="Fleet API (REST) server address (IPv4, IPv6, or a domain name)",
|
|
777
|
-
default=ADDRESS_FLEET_API_REST,
|
|
778
|
-
)
|
|
779
|
-
rest_group.add_argument(
|
|
780
|
-
"--ssl-certfile",
|
|
781
|
-
help="Fleet API (REST) server SSL certificate file (as a path str), "
|
|
782
|
-
"needed for using 'https'.",
|
|
783
|
-
default=None,
|
|
784
|
-
)
|
|
785
|
-
rest_group.add_argument(
|
|
786
|
-
"--ssl-keyfile",
|
|
787
|
-
help="Fleet API (REST) server SSL private key file (as a path str), "
|
|
788
|
-
"needed for using 'https'.",
|
|
789
|
-
default=None,
|
|
790
|
-
)
|
|
791
|
-
rest_group.add_argument(
|
|
792
|
-
"--rest-fleet-api-workers",
|
|
793
|
-
help="Set the number of concurrent workers for the Fleet API REST server.",
|
|
794
|
-
type=int,
|
|
795
|
-
default=1,
|
|
796
|
-
)
|
|
797
|
-
|
|
798
|
-
# Fleet API VCE options
|
|
799
|
-
vce_group = parser.add_argument_group("Fleet API (VCE) server options", "")
|
|
800
|
-
vce_group.add_argument(
|
|
801
|
-
"--client-app",
|
|
802
|
-
help="For example: `client:app` or `project.package.module:wrapper.app`.",
|
|
803
|
-
)
|
|
804
|
-
vce_group.add_argument(
|
|
805
|
-
"--num-supernodes",
|
|
806
|
-
type=int,
|
|
807
|
-
help="Number of simulated SuperNodes.",
|
|
808
|
-
)
|
|
809
|
-
vce_group.add_argument(
|
|
810
|
-
"--backend",
|
|
811
|
-
default="ray",
|
|
812
611
|
type=str,
|
|
813
|
-
|
|
612
|
+
choices=[TRANSPORT_TYPE_GRPC_RERE, TRANSPORT_TYPE_REST],
|
|
613
|
+
help="Start a gRPC-rere or REST (experimental) Fleet API server.",
|
|
814
614
|
)
|
|
815
|
-
|
|
816
|
-
"--
|
|
817
|
-
|
|
818
|
-
default='{"client_resources": {"num_cpus":1, "num_gpus":0.0}, "tensorflow": 0}',
|
|
819
|
-
help='A JSON formatted stream, e.g \'{"<keyA>":<value>, "<keyB>":<value>}\' to '
|
|
820
|
-
"configure a backend. Values supported in <value> are those included by "
|
|
821
|
-
"`flwr.common.typing.ConfigsRecordValues`. ",
|
|
615
|
+
parser.add_argument(
|
|
616
|
+
"--fleet-api-address",
|
|
617
|
+
help="Fleet API server address (IPv4, IPv6, or a domain name).",
|
|
822
618
|
)
|
|
823
619
|
parser.add_argument(
|
|
824
|
-
"--
|
|
825
|
-
default=
|
|
826
|
-
|
|
827
|
-
"
|
|
828
|
-
" Default: current working directory.",
|
|
620
|
+
"--fleet-api-num-workers",
|
|
621
|
+
default=1,
|
|
622
|
+
type=int,
|
|
623
|
+
help="Set the number of concurrent workers for the Fleet API server.",
|
|
829
624
|
)
|