flwr 1.14.0__py3-none-any.whl → 1.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/auth_plugin/__init__.py +31 -0
- flwr/cli/auth_plugin/oidc_cli_plugin.py +150 -0
- flwr/cli/cli_user_auth_interceptor.py +6 -2
- flwr/cli/config_utils.py +24 -147
- flwr/cli/constant.py +27 -0
- flwr/cli/install.py +1 -1
- flwr/cli/log.py +18 -3
- flwr/cli/login/login.py +43 -8
- flwr/cli/ls.py +14 -5
- flwr/cli/new/templates/app/README.md.tpl +3 -2
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +4 -4
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +3 -3
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
- flwr/cli/run/run.py +21 -11
- flwr/cli/stop.py +13 -4
- flwr/cli/utils.py +54 -40
- flwr/client/app.py +36 -48
- flwr/client/clientapp/app.py +19 -25
- flwr/client/clientapp/utils.py +1 -1
- flwr/client/grpc_client/connection.py +1 -12
- flwr/client/grpc_rere_client/client_interceptor.py +19 -119
- flwr/client/grpc_rere_client/connection.py +46 -36
- flwr/client/grpc_rere_client/grpc_adapter.py +12 -12
- flwr/client/message_handler/task_handler.py +0 -17
- flwr/client/rest_client/connection.py +34 -26
- flwr/client/supernode/app.py +18 -72
- flwr/common/args.py +25 -47
- flwr/common/auth_plugin/auth_plugin.py +34 -23
- flwr/common/config.py +166 -16
- flwr/common/constant.py +22 -9
- flwr/common/differential_privacy.py +2 -1
- flwr/common/exit/__init__.py +24 -0
- flwr/common/exit/exit.py +99 -0
- flwr/common/exit/exit_code.py +93 -0
- flwr/common/exit_handlers.py +24 -10
- flwr/common/grpc.py +167 -4
- flwr/common/logger.py +26 -7
- flwr/common/record/recordset.py +1 -1
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +45 -0
- flwr/common/serde.py +6 -4
- flwr/common/typing.py +20 -0
- flwr/proto/clientappio_pb2.py +1 -1
- flwr/proto/error_pb2.py +1 -1
- flwr/proto/exec_pb2.py +13 -25
- flwr/proto/exec_pb2.pyi +27 -54
- flwr/proto/fab_pb2.py +1 -1
- flwr/proto/fleet_pb2.py +31 -31
- flwr/proto/fleet_pb2.pyi +23 -23
- flwr/proto/fleet_pb2_grpc.py +30 -30
- flwr/proto/fleet_pb2_grpc.pyi +20 -20
- flwr/proto/grpcadapter_pb2.py +1 -1
- flwr/proto/log_pb2.py +1 -1
- flwr/proto/message_pb2.py +1 -1
- flwr/proto/node_pb2.py +3 -3
- flwr/proto/node_pb2.pyi +1 -4
- flwr/proto/recordset_pb2.py +1 -1
- flwr/proto/run_pb2.py +1 -1
- flwr/proto/serverappio_pb2.py +24 -25
- flwr/proto/serverappio_pb2.pyi +26 -32
- flwr/proto/serverappio_pb2_grpc.py +28 -28
- flwr/proto/serverappio_pb2_grpc.pyi +16 -16
- flwr/proto/simulationio_pb2.py +1 -1
- flwr/proto/task_pb2.py +1 -1
- flwr/proto/transport_pb2.py +1 -1
- flwr/server/app.py +116 -128
- flwr/server/compat/app_utils.py +0 -1
- flwr/server/compat/driver_client_proxy.py +1 -2
- flwr/server/driver/grpc_driver.py +32 -27
- flwr/server/driver/inmemory_driver.py +2 -1
- flwr/server/serverapp/app.py +12 -10
- flwr/server/superlink/driver/serverappio_grpc.py +1 -1
- flwr/server/superlink/driver/serverappio_servicer.py +74 -48
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +20 -88
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +2 -165
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +25 -24
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +97 -168
- flwr/server/superlink/fleet/message_handler/message_handler.py +37 -24
- flwr/server/superlink/fleet/rest_rere/rest_api.py +16 -18
- flwr/server/superlink/fleet/vce/vce_api.py +2 -2
- flwr/server/superlink/linkstate/in_memory_linkstate.py +45 -75
- flwr/server/superlink/linkstate/linkstate.py +17 -38
- flwr/server/superlink/linkstate/sqlite_linkstate.py +81 -145
- flwr/server/superlink/linkstate/utils.py +18 -8
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
- flwr/server/utils/validator.py +9 -34
- flwr/simulation/app.py +4 -6
- flwr/simulation/legacy_app.py +4 -2
- flwr/simulation/run_simulation.py +1 -1
- flwr/simulation/simulationio_connection.py +2 -1
- flwr/superexec/exec_grpc.py +1 -1
- flwr/superexec/exec_servicer.py +23 -2
- {flwr-1.14.0.dist-info → flwr-1.15.0.dist-info}/METADATA +8 -8
- {flwr-1.14.0.dist-info → flwr-1.15.0.dist-info}/RECORD +102 -96
- {flwr-1.14.0.dist-info → flwr-1.15.0.dist-info}/LICENSE +0 -0
- {flwr-1.14.0.dist-info → flwr-1.15.0.dist-info}/WHEEL +0 -0
- {flwr-1.14.0.dist-info → flwr-1.15.0.dist-info}/entry_points.txt +0 -0
flwr/server/app.py
CHANGED
|
@@ -18,7 +18,9 @@
|
|
|
18
18
|
import argparse
|
|
19
19
|
import csv
|
|
20
20
|
import importlib.util
|
|
21
|
-
import
|
|
21
|
+
import multiprocessing
|
|
22
|
+
import multiprocessing.context
|
|
23
|
+
import os
|
|
22
24
|
import sys
|
|
23
25
|
import threading
|
|
24
26
|
from collections.abc import Sequence
|
|
@@ -29,12 +31,8 @@ from typing import Any, Optional
|
|
|
29
31
|
|
|
30
32
|
import grpc
|
|
31
33
|
import yaml
|
|
32
|
-
from cryptography.exceptions import UnsupportedAlgorithm
|
|
33
34
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
34
|
-
from cryptography.hazmat.primitives.serialization import
|
|
35
|
-
load_ssh_private_key,
|
|
36
|
-
load_ssh_public_key,
|
|
37
|
-
)
|
|
35
|
+
from cryptography.hazmat.primitives.serialization import load_ssh_public_key
|
|
38
36
|
|
|
39
37
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
|
|
40
38
|
from flwr.common.address import parse_address
|
|
@@ -42,7 +40,7 @@ from flwr.common.args import try_obtain_server_certificates
|
|
|
42
40
|
from flwr.common.auth_plugin import ExecAuthPlugin
|
|
43
41
|
from flwr.common.config import get_flwr_dir, parse_config_args
|
|
44
42
|
from flwr.common.constant import (
|
|
45
|
-
|
|
43
|
+
AUTH_TYPE_KEY,
|
|
46
44
|
CLIENT_OCTET,
|
|
47
45
|
EXEC_API_DEFAULT_SERVER_ADDRESS,
|
|
48
46
|
FLEET_API_GRPC_BIDI_DEFAULT_ADDRESS,
|
|
@@ -50,7 +48,6 @@ from flwr.common.constant import (
|
|
|
50
48
|
FLEET_API_REST_DEFAULT_ADDRESS,
|
|
51
49
|
ISOLATION_MODE_PROCESS,
|
|
52
50
|
ISOLATION_MODE_SUBPROCESS,
|
|
53
|
-
MISSING_EXTRA_REST,
|
|
54
51
|
SERVER_OCTET,
|
|
55
52
|
SERVERAPPIO_API_DEFAULT_SERVER_ADDRESS,
|
|
56
53
|
SIMULATIONIO_API_DEFAULT_SERVER_ADDRESS,
|
|
@@ -58,16 +55,19 @@ from flwr.common.constant import (
|
|
|
58
55
|
TRANSPORT_TYPE_GRPC_RERE,
|
|
59
56
|
TRANSPORT_TYPE_REST,
|
|
60
57
|
)
|
|
58
|
+
from flwr.common.exit import ExitCode, flwr_exit
|
|
61
59
|
from flwr.common.exit_handlers import register_exit_handlers
|
|
60
|
+
from flwr.common.grpc import generic_create_grpc_server
|
|
62
61
|
from flwr.common.logger import log, warn_deprecated_feature
|
|
63
62
|
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
|
64
|
-
private_key_to_bytes,
|
|
65
63
|
public_key_to_bytes,
|
|
66
64
|
)
|
|
67
65
|
from flwr.proto.fleet_pb2_grpc import ( # pylint: disable=E0611
|
|
68
66
|
add_FleetServicer_to_server,
|
|
69
67
|
)
|
|
70
68
|
from flwr.proto.grpcadapter_pb2_grpc import add_GrpcAdapterServicer_to_server
|
|
69
|
+
from flwr.server.serverapp.app import flwr_serverapp
|
|
70
|
+
from flwr.simulation.app import flwr_simulation
|
|
71
71
|
from flwr.superexec.app import load_executor
|
|
72
72
|
from flwr.superexec.exec_grpc import run_exec_api_grpc
|
|
73
73
|
|
|
@@ -79,10 +79,7 @@ from .strategy import Strategy
|
|
|
79
79
|
from .superlink.driver.serverappio_grpc import run_serverappio_api_grpc
|
|
80
80
|
from .superlink.ffs.ffs_factory import FfsFactory
|
|
81
81
|
from .superlink.fleet.grpc_adapter.grpc_adapter_servicer import GrpcAdapterServicer
|
|
82
|
-
from .superlink.fleet.grpc_bidi.grpc_server import
|
|
83
|
-
generic_create_grpc_server,
|
|
84
|
-
start_grpc_server,
|
|
85
|
-
)
|
|
82
|
+
from .superlink.fleet.grpc_bidi.grpc_server import start_grpc_server
|
|
86
83
|
from .superlink.fleet.grpc_rere.fleet_servicer import FleetServicer
|
|
87
84
|
from .superlink.fleet.grpc_rere.server_interceptor import AuthenticateServerInterceptor
|
|
88
85
|
from .superlink.linkstate import LinkStateFactory
|
|
@@ -226,6 +223,13 @@ def start_server( # pylint: disable=too-many-arguments,too-many-locals
|
|
|
226
223
|
"enabled" if certificates is not None else "disabled",
|
|
227
224
|
)
|
|
228
225
|
|
|
226
|
+
# Graceful shutdown
|
|
227
|
+
register_exit_handlers(
|
|
228
|
+
event_type=EventType.START_SERVER_LEAVE,
|
|
229
|
+
exit_message="Flower server terminated gracefully.",
|
|
230
|
+
grpc_servers=[grpc_server],
|
|
231
|
+
)
|
|
232
|
+
|
|
229
233
|
# Start training
|
|
230
234
|
hist = run_fl(
|
|
231
235
|
server=initialized_server,
|
|
@@ -261,13 +265,16 @@ def run_superlink() -> None:
|
|
|
261
265
|
simulationio_address, _, _ = _format_address(args.simulationio_api_address)
|
|
262
266
|
|
|
263
267
|
# Obtain certificates
|
|
264
|
-
certificates = try_obtain_server_certificates(args
|
|
268
|
+
certificates = try_obtain_server_certificates(args)
|
|
269
|
+
|
|
270
|
+
# Disable the user auth TLS check if args.disable_oidc_tls_cert_verification is
|
|
271
|
+
# provided
|
|
272
|
+
verify_tls_cert = not getattr(args, "disable_oidc_tls_cert_verification", None)
|
|
265
273
|
|
|
266
|
-
user_auth_config = _try_obtain_user_auth_config(args)
|
|
267
274
|
auth_plugin: Optional[ExecAuthPlugin] = None
|
|
268
|
-
#
|
|
269
|
-
if
|
|
270
|
-
auth_plugin = _try_obtain_exec_auth_plugin(
|
|
275
|
+
# Load the auth plugin if the args.user_auth_config is provided
|
|
276
|
+
if cfg_path := getattr(args, "user_auth_config", None):
|
|
277
|
+
auth_plugin = _try_obtain_exec_auth_plugin(Path(cfg_path), verify_tls_cert)
|
|
271
278
|
|
|
272
279
|
# Initialize StateFactory
|
|
273
280
|
state_factory = LinkStateFactory(args.database)
|
|
@@ -293,7 +300,7 @@ def run_superlink() -> None:
|
|
|
293
300
|
# Determine Exec plugin
|
|
294
301
|
# If simulation is used, don't start ServerAppIo and Fleet APIs
|
|
295
302
|
sim_exec = executor.__class__.__qualname__ == "SimulationEngine"
|
|
296
|
-
bckg_threads = []
|
|
303
|
+
bckg_threads: list[threading.Thread] = []
|
|
297
304
|
|
|
298
305
|
if sim_exec:
|
|
299
306
|
simulationio_server: grpc.Server = run_simulationio_api_grpc(
|
|
@@ -344,48 +351,40 @@ def run_superlink() -> None:
|
|
|
344
351
|
and importlib.util.find_spec("starlette")
|
|
345
352
|
and importlib.util.find_spec("uvicorn")
|
|
346
353
|
) is None:
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
_, ssl_certfile, ssl_keyfile = (
|
|
350
|
-
certificates if certificates is not None else (None, None, None)
|
|
351
|
-
)
|
|
354
|
+
flwr_exit(ExitCode.COMMON_MISSING_EXTRA_REST)
|
|
352
355
|
|
|
353
356
|
fleet_thread = threading.Thread(
|
|
354
357
|
target=_run_fleet_api_rest,
|
|
355
358
|
args=(
|
|
356
359
|
host,
|
|
357
360
|
port,
|
|
358
|
-
ssl_keyfile,
|
|
359
|
-
ssl_certfile,
|
|
361
|
+
args.ssl_keyfile,
|
|
362
|
+
args.ssl_certfile,
|
|
360
363
|
state_factory,
|
|
361
364
|
ffs_factory,
|
|
362
365
|
num_workers,
|
|
363
366
|
),
|
|
367
|
+
daemon=True,
|
|
364
368
|
)
|
|
365
369
|
fleet_thread.start()
|
|
366
370
|
bckg_threads.append(fleet_thread)
|
|
367
371
|
elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE:
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
if
|
|
371
|
-
|
|
372
|
-
node_public_keys,
|
|
373
|
-
server_private_key,
|
|
374
|
-
server_public_key,
|
|
375
|
-
) = maybe_keys
|
|
372
|
+
node_public_keys = _try_load_public_keys_node_authentication(args)
|
|
373
|
+
auto_auth = True
|
|
374
|
+
if node_public_keys is not None:
|
|
375
|
+
auto_auth = False
|
|
376
376
|
state = state_factory.state()
|
|
377
|
-
state.
|
|
377
|
+
state.clear_supernode_auth_keys()
|
|
378
378
|
state.store_node_public_keys(node_public_keys)
|
|
379
|
-
state.store_server_private_public_key(
|
|
380
|
-
private_key_to_bytes(server_private_key),
|
|
381
|
-
public_key_to_bytes(server_public_key),
|
|
382
|
-
)
|
|
383
379
|
log(
|
|
384
380
|
INFO,
|
|
385
381
|
"Node authentication enabled with %d known public keys",
|
|
386
382
|
len(node_public_keys),
|
|
387
383
|
)
|
|
388
|
-
|
|
384
|
+
else:
|
|
385
|
+
log(DEBUG, "Automatic node authentication enabled")
|
|
386
|
+
|
|
387
|
+
interceptors = [AuthenticateServerInterceptor(state_factory, auto_auth)]
|
|
389
388
|
|
|
390
389
|
fleet_server = _run_fleet_api_grpc_rere(
|
|
391
390
|
address=fleet_address,
|
|
@@ -427,6 +426,7 @@ def run_superlink() -> None:
|
|
|
427
426
|
address,
|
|
428
427
|
cmd,
|
|
429
428
|
),
|
|
429
|
+
daemon=True,
|
|
430
430
|
)
|
|
431
431
|
scheduler_th.start()
|
|
432
432
|
bckg_threads.append(scheduler_th)
|
|
@@ -434,17 +434,37 @@ def run_superlink() -> None:
|
|
|
434
434
|
# Graceful shutdown
|
|
435
435
|
register_exit_handlers(
|
|
436
436
|
event_type=EventType.RUN_SUPERLINK_LEAVE,
|
|
437
|
+
exit_message="SuperLink terminated gracefully.",
|
|
437
438
|
grpc_servers=grpc_servers,
|
|
438
|
-
bckg_threads=bckg_threads,
|
|
439
439
|
)
|
|
440
440
|
|
|
441
|
-
# Block
|
|
442
|
-
while
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
441
|
+
# Block until a thread exits prematurely
|
|
442
|
+
while all(thread.is_alive() for thread in bckg_threads):
|
|
443
|
+
sleep(0.1)
|
|
444
|
+
|
|
445
|
+
# Exit if any thread has exited prematurely
|
|
446
|
+
# This code will not be reached if the SuperLink stops gracefully
|
|
447
|
+
flwr_exit(ExitCode.SUPERLINK_THREAD_CRASH)
|
|
448
|
+
|
|
449
|
+
|
|
450
|
+
def _run_flwr_command(args: list[str], main_pid: int) -> None:
|
|
451
|
+
# Monitor the main process in case of SIGKILL
|
|
452
|
+
def main_process_monitor() -> None:
|
|
453
|
+
while True:
|
|
454
|
+
sleep(1)
|
|
455
|
+
if os.getppid() != main_pid:
|
|
456
|
+
os.kill(os.getpid(), 9)
|
|
457
|
+
|
|
458
|
+
threading.Thread(target=main_process_monitor, daemon=True).start()
|
|
459
|
+
|
|
460
|
+
# Run the command
|
|
461
|
+
sys.argv = args
|
|
462
|
+
if args[0] == "flwr-serverapp":
|
|
463
|
+
flwr_serverapp()
|
|
464
|
+
elif args[0] == "flwr-simulation":
|
|
465
|
+
flwr_simulation()
|
|
466
|
+
else:
|
|
467
|
+
raise ValueError(f"Unknown command: {args[0]}")
|
|
448
468
|
|
|
449
469
|
|
|
450
470
|
def _flwr_scheduler(
|
|
@@ -454,15 +474,18 @@ def _flwr_scheduler(
|
|
|
454
474
|
cmd: str,
|
|
455
475
|
) -> None:
|
|
456
476
|
log(DEBUG, "Started %s scheduler thread.", cmd)
|
|
457
|
-
|
|
458
477
|
state = state_factory.state()
|
|
478
|
+
run_id_to_proc: dict[int, multiprocessing.context.SpawnProcess] = {}
|
|
479
|
+
|
|
480
|
+
# Use the "spawn" start method for multiprocessing.
|
|
481
|
+
mp_spawn_context = multiprocessing.get_context("spawn")
|
|
459
482
|
|
|
460
483
|
# Periodically check for a pending run in the LinkState
|
|
461
484
|
while True:
|
|
462
|
-
sleep(
|
|
485
|
+
sleep(0.1)
|
|
463
486
|
pending_run_id = state.get_pending_run_id()
|
|
464
487
|
|
|
465
|
-
if pending_run_id:
|
|
488
|
+
if pending_run_id and pending_run_id not in run_id_to_proc:
|
|
466
489
|
|
|
467
490
|
log(
|
|
468
491
|
INFO,
|
|
@@ -479,50 +502,45 @@ def _flwr_scheduler(
|
|
|
479
502
|
"--insecure",
|
|
480
503
|
]
|
|
481
504
|
|
|
482
|
-
|
|
483
|
-
command,
|
|
484
|
-
text=True,
|
|
505
|
+
proc = mp_spawn_context.Process(
|
|
506
|
+
target=_run_flwr_command, args=(command, os.getpid()), daemon=True
|
|
485
507
|
)
|
|
508
|
+
proc.start()
|
|
509
|
+
|
|
510
|
+
# Store the process
|
|
511
|
+
run_id_to_proc[pending_run_id] = proc
|
|
512
|
+
|
|
513
|
+
# Clean up finished processes
|
|
514
|
+
for run_id, proc in list(run_id_to_proc.items()):
|
|
515
|
+
if not proc.is_alive():
|
|
516
|
+
del run_id_to_proc[run_id]
|
|
486
517
|
|
|
487
518
|
|
|
488
519
|
def _format_address(address: str) -> tuple[str, str, int]:
|
|
489
520
|
parsed_address = parse_address(address)
|
|
490
521
|
if not parsed_address:
|
|
491
|
-
|
|
492
|
-
|
|
522
|
+
flwr_exit(
|
|
523
|
+
ExitCode.COMMON_ADDRESS_INVALID,
|
|
524
|
+
f"Address ({address}) cannot be parsed.",
|
|
493
525
|
)
|
|
494
526
|
host, port, is_v6 = parsed_address
|
|
495
527
|
return (f"[{host}]:{port}" if is_v6 else f"{host}:{port}", host, port)
|
|
496
528
|
|
|
497
529
|
|
|
498
|
-
def
|
|
530
|
+
def _try_load_public_keys_node_authentication(
|
|
499
531
|
args: argparse.Namespace,
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
if
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
if (
|
|
510
|
-
not args.auth_list_public_keys
|
|
511
|
-
or not args.auth_superlink_private_key
|
|
512
|
-
or not args.auth_superlink_public_key
|
|
513
|
-
):
|
|
514
|
-
sys.exit(
|
|
515
|
-
"Authentication requires providing file paths for "
|
|
516
|
-
"'--auth-list-public-keys', '--auth-superlink-private-key' and "
|
|
517
|
-
"'--auth-superlink-public-key'. Provide all three to enable authentication."
|
|
532
|
+
) -> Optional[set[bytes]]:
|
|
533
|
+
"""Return a set of node public keys."""
|
|
534
|
+
if args.auth_superlink_private_key or args.auth_superlink_public_key:
|
|
535
|
+
log(
|
|
536
|
+
WARN,
|
|
537
|
+
"The `--auth-superlink-private-key` and `--auth-superlink-public-key` "
|
|
538
|
+
"arguments are deprecated and will be removed in a future release. Node "
|
|
539
|
+
"authentication no longer requires these arguments.",
|
|
518
540
|
)
|
|
519
541
|
|
|
520
|
-
if
|
|
521
|
-
|
|
522
|
-
"Authentication requires secure connections. "
|
|
523
|
-
"Please provide certificate paths to `--ssl-certfile`, "
|
|
524
|
-
"`--ssl-keyfile`, and `—-ssl-ca-certfile` and try again."
|
|
525
|
-
)
|
|
542
|
+
if not args.auth_list_public_keys:
|
|
543
|
+
return None
|
|
526
544
|
|
|
527
545
|
node_keys_file_path = Path(args.auth_list_public_keys)
|
|
528
546
|
if not node_keys_file_path.exists():
|
|
@@ -535,35 +553,6 @@ def _try_setup_node_authentication(
|
|
|
535
553
|
|
|
536
554
|
node_public_keys: set[bytes] = set()
|
|
537
555
|
|
|
538
|
-
try:
|
|
539
|
-
ssh_private_key = load_ssh_private_key(
|
|
540
|
-
Path(args.auth_superlink_private_key).read_bytes(),
|
|
541
|
-
None,
|
|
542
|
-
)
|
|
543
|
-
if not isinstance(ssh_private_key, ec.EllipticCurvePrivateKey):
|
|
544
|
-
raise ValueError()
|
|
545
|
-
except (ValueError, UnsupportedAlgorithm):
|
|
546
|
-
sys.exit(
|
|
547
|
-
"Error: Unable to parse the private key file in "
|
|
548
|
-
"'--auth-superlink-private-key'. Authentication requires elliptic "
|
|
549
|
-
"curve private and public key pair. Please ensure that the file "
|
|
550
|
-
"path points to a valid private key file and try again."
|
|
551
|
-
)
|
|
552
|
-
|
|
553
|
-
try:
|
|
554
|
-
ssh_public_key = load_ssh_public_key(
|
|
555
|
-
Path(args.auth_superlink_public_key).read_bytes()
|
|
556
|
-
)
|
|
557
|
-
if not isinstance(ssh_public_key, ec.EllipticCurvePublicKey):
|
|
558
|
-
raise ValueError()
|
|
559
|
-
except (ValueError, UnsupportedAlgorithm):
|
|
560
|
-
sys.exit(
|
|
561
|
-
"Error: Unable to parse the public key file in "
|
|
562
|
-
"'--auth-superlink-public-key'. Authentication requires elliptic "
|
|
563
|
-
"curve private and public key pair. Please ensure that the file "
|
|
564
|
-
"path points to a valid public key file and try again."
|
|
565
|
-
)
|
|
566
|
-
|
|
567
556
|
with open(node_keys_file_path, newline="", encoding="utf-8") as csvfile:
|
|
568
557
|
reader = csv.reader(csvfile)
|
|
569
558
|
for row in reader:
|
|
@@ -577,28 +566,27 @@ def _try_setup_node_authentication(
|
|
|
577
566
|
"file. Please ensure that the CSV file path points to a valid "
|
|
578
567
|
"known SSH public keys files and try again."
|
|
579
568
|
)
|
|
580
|
-
|
|
581
|
-
node_public_keys,
|
|
582
|
-
ssh_private_key,
|
|
583
|
-
ssh_public_key,
|
|
584
|
-
)
|
|
585
|
-
|
|
569
|
+
return node_public_keys
|
|
586
570
|
|
|
587
|
-
def _try_obtain_user_auth_config(args: argparse.Namespace) -> Optional[dict[str, Any]]:
|
|
588
|
-
if getattr(args, "user_auth_config", None) is not None:
|
|
589
|
-
with open(args.user_auth_config, encoding="utf-8") as file:
|
|
590
|
-
config: dict[str, Any] = yaml.safe_load(file)
|
|
591
|
-
return config
|
|
592
|
-
return None
|
|
593
571
|
|
|
572
|
+
def _try_obtain_exec_auth_plugin(
|
|
573
|
+
config_path: Path, verify_tls_cert: bool
|
|
574
|
+
) -> Optional[ExecAuthPlugin]:
|
|
575
|
+
# Load YAML file
|
|
576
|
+
with config_path.open("r", encoding="utf-8") as file:
|
|
577
|
+
config: dict[str, Any] = yaml.safe_load(file)
|
|
594
578
|
|
|
595
|
-
|
|
579
|
+
# Load authentication configuration
|
|
596
580
|
auth_config: dict[str, Any] = config.get("authentication", {})
|
|
597
|
-
auth_type: str = auth_config.get(
|
|
581
|
+
auth_type: str = auth_config.get(AUTH_TYPE_KEY, "")
|
|
582
|
+
|
|
583
|
+
# Load authentication plugin
|
|
598
584
|
try:
|
|
599
585
|
all_plugins: dict[str, type[ExecAuthPlugin]] = get_exec_auth_plugins()
|
|
600
586
|
auth_plugin_class = all_plugins[auth_type]
|
|
601
|
-
return auth_plugin_class(
|
|
587
|
+
return auth_plugin_class(
|
|
588
|
+
user_auth_config_path=config_path, verify_tls_cert=verify_tls_cert
|
|
589
|
+
)
|
|
602
590
|
except KeyError:
|
|
603
591
|
if auth_type != "":
|
|
604
592
|
sys.exit(
|
|
@@ -681,7 +669,7 @@ def _run_fleet_api_rest(
|
|
|
681
669
|
|
|
682
670
|
from flwr.server.superlink.fleet.rest_rere.rest_api import app as fast_api_app
|
|
683
671
|
except ModuleNotFoundError:
|
|
684
|
-
|
|
672
|
+
flwr_exit(ExitCode.COMMON_MISSING_EXTRA_REST)
|
|
685
673
|
|
|
686
674
|
log(INFO, "Starting Flower REST server")
|
|
687
675
|
|
|
@@ -791,12 +779,12 @@ def _add_args_common(parser: argparse.ArgumentParser) -> None:
|
|
|
791
779
|
parser.add_argument(
|
|
792
780
|
"--auth-superlink-private-key",
|
|
793
781
|
type=str,
|
|
794
|
-
help="
|
|
782
|
+
help="This argument is deprecated and will be removed in a future release.",
|
|
795
783
|
)
|
|
796
784
|
parser.add_argument(
|
|
797
785
|
"--auth-superlink-public-key",
|
|
798
786
|
type=str,
|
|
799
|
-
help="
|
|
787
|
+
help="This argument is deprecated and will be removed in a future release.",
|
|
800
788
|
)
|
|
801
789
|
|
|
802
790
|
|
flwr/server/compat/app_utils.py
CHANGED
|
@@ -28,12 +28,11 @@ from ..driver.driver import Driver
|
|
|
28
28
|
class DriverClientProxy(ClientProxy):
|
|
29
29
|
"""Flower client proxy which delegates work using the Driver API."""
|
|
30
30
|
|
|
31
|
-
def __init__(self, node_id: int, driver: Driver,
|
|
31
|
+
def __init__(self, node_id: int, driver: Driver, run_id: int):
|
|
32
32
|
super().__init__(str(node_id))
|
|
33
33
|
self.node_id = node_id
|
|
34
34
|
self.driver = driver
|
|
35
35
|
self.run_id = run_id
|
|
36
|
-
self.anonymous = anonymous
|
|
37
36
|
|
|
38
37
|
def get_properties(
|
|
39
38
|
self,
|
|
@@ -24,29 +24,32 @@ from typing import Optional, cast
|
|
|
24
24
|
import grpc
|
|
25
25
|
|
|
26
26
|
from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
|
|
27
|
-
from flwr.common.constant import
|
|
28
|
-
|
|
27
|
+
from flwr.common.constant import (
|
|
28
|
+
SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS,
|
|
29
|
+
SUPERLINK_NODE_ID,
|
|
30
|
+
)
|
|
31
|
+
from flwr.common.grpc import create_channel, on_channel_state_change
|
|
29
32
|
from flwr.common.logger import log
|
|
30
33
|
from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub
|
|
31
|
-
from flwr.common.serde import
|
|
34
|
+
from flwr.common.serde import message_from_proto, message_to_proto, run_from_proto
|
|
32
35
|
from flwr.common.typing import Run
|
|
36
|
+
from flwr.proto.message_pb2 import Message as ProtoMessage # pylint: disable=E0611
|
|
33
37
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
34
38
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
35
39
|
from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
|
|
36
40
|
GetNodesRequest,
|
|
37
41
|
GetNodesResponse,
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
+
PullResMessagesRequest,
|
|
43
|
+
PullResMessagesResponse,
|
|
44
|
+
PushInsMessagesRequest,
|
|
45
|
+
PushInsMessagesResponse,
|
|
42
46
|
)
|
|
43
47
|
from flwr.proto.serverappio_pb2_grpc import ServerAppIoStub # pylint: disable=E0611
|
|
44
|
-
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
|
|
45
48
|
|
|
46
49
|
from .driver import Driver
|
|
47
50
|
|
|
48
51
|
ERROR_MESSAGE_DRIVER_NOT_CONNECTED = """
|
|
49
|
-
[
|
|
52
|
+
[flwr-serverapp] Error: Not connected.
|
|
50
53
|
|
|
51
54
|
Call `connect()` on the `GrpcDriverStub` instance before calling any of the other
|
|
52
55
|
`GrpcDriverStub` methods.
|
|
@@ -76,7 +79,7 @@ class GrpcDriver(Driver):
|
|
|
76
79
|
self._run: Optional[Run] = None
|
|
77
80
|
self._grpc_stub: Optional[ServerAppIoStub] = None
|
|
78
81
|
self._channel: Optional[grpc.Channel] = None
|
|
79
|
-
self.node = Node(node_id=
|
|
82
|
+
self.node = Node(node_id=SUPERLINK_NODE_ID)
|
|
80
83
|
self._retry_invoker = _make_simple_grpc_retry_invoker()
|
|
81
84
|
|
|
82
85
|
@property
|
|
@@ -97,9 +100,10 @@ class GrpcDriver(Driver):
|
|
|
97
100
|
insecure=(self._cert is None),
|
|
98
101
|
root_certificates=self._cert,
|
|
99
102
|
)
|
|
103
|
+
self._channel.subscribe(on_channel_state_change)
|
|
100
104
|
self._grpc_stub = ServerAppIoStub(self._channel)
|
|
101
105
|
_wrap_stub(self._grpc_stub, self._retry_invoker)
|
|
102
|
-
log(DEBUG, "[
|
|
106
|
+
log(DEBUG, "[flwr-serverapp] Connected to %s", self._addr)
|
|
103
107
|
|
|
104
108
|
def _disconnect(self) -> None:
|
|
105
109
|
"""Disconnect from the ServerAppIo API."""
|
|
@@ -110,7 +114,7 @@ class GrpcDriver(Driver):
|
|
|
110
114
|
self._channel = None
|
|
111
115
|
self._grpc_stub = None
|
|
112
116
|
channel.close()
|
|
113
|
-
log(DEBUG, "[
|
|
117
|
+
log(DEBUG, "[flwr-serverapp] Disconnected")
|
|
114
118
|
|
|
115
119
|
def set_run(self, run_id: int) -> None:
|
|
116
120
|
"""Set the run."""
|
|
@@ -193,22 +197,22 @@ class GrpcDriver(Driver):
|
|
|
193
197
|
This method takes an iterable of messages and sends each message
|
|
194
198
|
to the node specified in `dst_node_id`.
|
|
195
199
|
"""
|
|
196
|
-
# Construct
|
|
197
|
-
|
|
200
|
+
# Construct Messages
|
|
201
|
+
message_proto_list: list[ProtoMessage] = []
|
|
198
202
|
for msg in messages:
|
|
199
203
|
# Check message
|
|
200
204
|
self._check_message(msg)
|
|
201
|
-
# Convert
|
|
202
|
-
|
|
205
|
+
# Convert to proto
|
|
206
|
+
msg_proto = message_to_proto(msg)
|
|
203
207
|
# Add to list
|
|
204
|
-
|
|
208
|
+
message_proto_list.append(msg_proto)
|
|
205
209
|
# Call GrpcDriverStub method
|
|
206
|
-
res:
|
|
207
|
-
|
|
208
|
-
|
|
210
|
+
res: PushInsMessagesResponse = self._stub.PushMessages(
|
|
211
|
+
PushInsMessagesRequest(
|
|
212
|
+
messages_list=message_proto_list, run_id=cast(Run, self._run).run_id
|
|
209
213
|
)
|
|
210
214
|
)
|
|
211
|
-
return list(res.
|
|
215
|
+
return list(res.message_ids)
|
|
212
216
|
|
|
213
217
|
def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
|
|
214
218
|
"""Pull messages based on message IDs.
|
|
@@ -216,14 +220,15 @@ class GrpcDriver(Driver):
|
|
|
216
220
|
This method is used to collect messages from the SuperLink that correspond to a
|
|
217
221
|
set of given message IDs.
|
|
218
222
|
"""
|
|
219
|
-
# Pull
|
|
220
|
-
res:
|
|
221
|
-
|
|
222
|
-
|
|
223
|
+
# Pull Messages
|
|
224
|
+
res: PullResMessagesResponse = self._stub.PullMessages(
|
|
225
|
+
PullResMessagesRequest(
|
|
226
|
+
message_ids=message_ids,
|
|
227
|
+
run_id=cast(Run, self._run).run_id,
|
|
223
228
|
)
|
|
224
229
|
)
|
|
225
|
-
# Convert
|
|
226
|
-
msgs = [
|
|
230
|
+
# Convert Message from Protobuf representation
|
|
231
|
+
msgs = [message_from_proto(msg_proto) for msg_proto in res.messages_list]
|
|
227
232
|
return msgs
|
|
228
233
|
|
|
229
234
|
def send_and_receive(
|
|
@@ -22,6 +22,7 @@ from typing import Optional, cast
|
|
|
22
22
|
from uuid import UUID
|
|
23
23
|
|
|
24
24
|
from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
|
|
25
|
+
from flwr.common.constant import SUPERLINK_NODE_ID
|
|
25
26
|
from flwr.common.serde import message_from_taskres, message_to_taskins
|
|
26
27
|
from flwr.common.typing import Run
|
|
27
28
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
@@ -49,7 +50,7 @@ class InMemoryDriver(Driver):
|
|
|
49
50
|
self._run: Optional[Run] = None
|
|
50
51
|
self.state = state_factory.state()
|
|
51
52
|
self.pull_interval = pull_interval
|
|
52
|
-
self.node = Node(node_id=
|
|
53
|
+
self.node = Node(node_id=SUPERLINK_NODE_ID)
|
|
53
54
|
|
|
54
55
|
def _check_message(self, message: Message) -> None:
|
|
55
56
|
# Check if the message is valid
|
flwr/server/serverapp/app.py
CHANGED
|
@@ -16,7 +16,6 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import argparse
|
|
19
|
-
import sys
|
|
20
19
|
from logging import DEBUG, ERROR, INFO
|
|
21
20
|
from pathlib import Path
|
|
22
21
|
from queue import Queue
|
|
@@ -38,6 +37,7 @@ from flwr.common.constant import (
|
|
|
38
37
|
Status,
|
|
39
38
|
SubStatus,
|
|
40
39
|
)
|
|
40
|
+
from flwr.common.exit import ExitCode, flwr_exit
|
|
41
41
|
from flwr.common.logger import (
|
|
42
42
|
log,
|
|
43
43
|
mirror_output_to_queue,
|
|
@@ -72,19 +72,18 @@ def flwr_serverapp() -> None:
|
|
|
72
72
|
|
|
73
73
|
args = _parse_args_run_flwr_serverapp().parse_args()
|
|
74
74
|
|
|
75
|
-
log(INFO, "
|
|
75
|
+
log(INFO, "Start `flwr-serverapp` process")
|
|
76
76
|
|
|
77
77
|
if not args.insecure:
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
"`flwr-serverapp` does not support TLS yet.
|
|
81
|
-
"Please use the '--insecure' flag.",
|
|
78
|
+
flwr_exit(
|
|
79
|
+
ExitCode.COMMON_TLS_NOT_SUPPORTED,
|
|
80
|
+
"`flwr-serverapp` does not support TLS yet.",
|
|
82
81
|
)
|
|
83
|
-
sys.exit(1)
|
|
84
82
|
|
|
85
83
|
log(
|
|
86
84
|
DEBUG,
|
|
87
|
-
"
|
|
85
|
+
"`flwr-serverapp` will attempt to connect to SuperLink's "
|
|
86
|
+
"ServerAppIo API at %s",
|
|
88
87
|
args.serverappio_api_address,
|
|
89
88
|
)
|
|
90
89
|
run_serverapp(
|
|
@@ -117,11 +116,13 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
|
117
116
|
log_uploader = None
|
|
118
117
|
success = True
|
|
119
118
|
hash_run_id = None
|
|
119
|
+
run_status = None
|
|
120
120
|
while True:
|
|
121
121
|
|
|
122
122
|
try:
|
|
123
123
|
# Pull ServerAppInputs from LinkState
|
|
124
124
|
req = PullServerAppInputsRequest()
|
|
125
|
+
log(DEBUG, "[flwr-serverapp] Pull ServerAppInputs")
|
|
125
126
|
res: PullServerAppInputsResponse = driver._stub.PullServerAppInputs(req)
|
|
126
127
|
if not res.HasField("run"):
|
|
127
128
|
sleep(3)
|
|
@@ -144,7 +145,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
|
144
145
|
stub=driver._stub,
|
|
145
146
|
)
|
|
146
147
|
|
|
147
|
-
log(DEBUG, "
|
|
148
|
+
log(DEBUG, "[flwr-serverapp] Start FAB installation.")
|
|
148
149
|
install_from_fab(fab.content, flwr_dir=flwr_dir_, skip_prompt=True)
|
|
149
150
|
|
|
150
151
|
fab_id, fab_version = get_fab_metadata(fab.content)
|
|
@@ -165,7 +166,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
|
165
166
|
|
|
166
167
|
log(
|
|
167
168
|
DEBUG,
|
|
168
|
-
"
|
|
169
|
+
"[flwr-serverapp] Will load ServerApp `%s` in %s",
|
|
169
170
|
server_app_attr,
|
|
170
171
|
app_path,
|
|
171
172
|
)
|
|
@@ -191,6 +192,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
|
|
|
191
192
|
|
|
192
193
|
# Send resulting context
|
|
193
194
|
context_proto = context_to_proto(updated_context)
|
|
195
|
+
log(DEBUG, "[flwr-serverapp] Will push ServerAppOutputs")
|
|
194
196
|
out_req = PushServerAppOutputsRequest(
|
|
195
197
|
run_id=run.run_id, context=context_proto
|
|
196
198
|
)
|