flwr-nightly 1.8.0.dev20240314__py3-none-any.whl → 1.11.0.dev20240813__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/app.py +7 -0
- flwr/cli/build.py +150 -0
- flwr/cli/config_utils.py +219 -0
- flwr/cli/example.py +3 -1
- flwr/cli/install.py +227 -0
- flwr/cli/new/new.py +179 -48
- flwr/cli/new/templates/app/.gitignore.tpl +160 -0
- flwr/cli/new/templates/app/README.flowertune.md.tpl +56 -0
- flwr/cli/new/templates/app/README.md.tpl +1 -5
- flwr/cli/new/templates/app/code/__init__.py.tpl +1 -1
- flwr/cli/new/templates/app/code/client.huggingface.py.tpl +65 -0
- flwr/cli/new/templates/app/code/client.jax.py.tpl +56 -0
- flwr/cli/new/templates/app/code/client.mlx.py.tpl +93 -0
- flwr/cli/new/templates/app/code/client.numpy.py.tpl +3 -2
- flwr/cli/new/templates/app/code/client.pytorch.py.tpl +23 -11
- flwr/cli/new/templates/app/code/client.sklearn.py.tpl +97 -0
- flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +60 -1
- flwr/cli/new/templates/app/code/flwr_tune/__init__.py +15 -0
- flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +89 -0
- flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +126 -0
- flwr/cli/new/templates/app/code/flwr_tune/config.yaml.tpl +34 -0
- flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +57 -0
- flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +59 -0
- flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +48 -0
- flwr/cli/new/templates/app/code/flwr_tune/static_config.yaml.tpl +11 -0
- flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -0
- flwr/cli/new/templates/app/code/server.jax.py.tpl +20 -0
- flwr/cli/new/templates/app/code/server.mlx.py.tpl +20 -0
- flwr/cli/new/templates/app/code/server.numpy.py.tpl +17 -9
- flwr/cli/new/templates/app/code/server.pytorch.py.tpl +21 -18
- flwr/cli/new/templates/app/code/server.sklearn.py.tpl +24 -0
- flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +29 -1
- flwr/cli/new/templates/app/code/task.huggingface.py.tpl +99 -0
- flwr/cli/new/templates/app/code/task.jax.py.tpl +57 -0
- flwr/cli/new/templates/app/code/task.mlx.py.tpl +102 -0
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +28 -23
- flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +53 -0
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +39 -0
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +38 -0
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +34 -0
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +39 -0
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +25 -12
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +29 -14
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +33 -0
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +29 -14
- flwr/cli/run/run.py +168 -17
- flwr/cli/utils.py +75 -4
- flwr/client/__init__.py +6 -1
- flwr/client/app.py +239 -248
- flwr/client/client_app.py +70 -9
- flwr/client/dpfedavg_numpy_client.py +1 -1
- flwr/client/grpc_adapter_client/__init__.py +15 -0
- flwr/client/grpc_adapter_client/connection.py +97 -0
- flwr/client/grpc_client/connection.py +18 -5
- flwr/client/grpc_rere_client/__init__.py +1 -1
- flwr/client/grpc_rere_client/client_interceptor.py +158 -0
- flwr/client/grpc_rere_client/connection.py +127 -33
- flwr/client/grpc_rere_client/grpc_adapter.py +140 -0
- flwr/client/heartbeat.py +74 -0
- flwr/client/message_handler/__init__.py +1 -1
- flwr/client/message_handler/message_handler.py +7 -7
- flwr/client/mod/__init__.py +5 -5
- flwr/client/mod/centraldp_mods.py +4 -2
- flwr/client/mod/comms_mods.py +4 -4
- flwr/client/mod/localdp_mod.py +9 -4
- flwr/client/mod/secure_aggregation/__init__.py +1 -1
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +1 -1
- flwr/client/mod/utils.py +1 -1
- flwr/client/node_state.py +60 -10
- flwr/client/node_state_tests.py +4 -3
- flwr/client/rest_client/__init__.py +1 -1
- flwr/client/rest_client/connection.py +177 -157
- flwr/client/supernode/__init__.py +26 -0
- flwr/client/supernode/app.py +464 -0
- flwr/client/typing.py +1 -0
- flwr/common/__init__.py +13 -11
- flwr/common/address.py +1 -1
- flwr/common/config.py +193 -0
- flwr/common/constant.py +42 -1
- flwr/common/context.py +26 -1
- flwr/common/date.py +1 -1
- flwr/common/dp.py +1 -1
- flwr/common/grpc.py +6 -2
- flwr/common/logger.py +79 -8
- flwr/common/message.py +167 -105
- flwr/common/object_ref.py +126 -25
- flwr/common/record/__init__.py +1 -1
- flwr/common/record/parametersrecord.py +0 -1
- flwr/common/record/recordset.py +78 -27
- flwr/common/recordset_compat.py +8 -1
- flwr/common/retry_invoker.py +25 -13
- flwr/common/secure_aggregation/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/__init__.py +1 -1
- flwr/common/secure_aggregation/crypto/shamir.py +1 -1
- flwr/common/secure_aggregation/crypto/symmetric_encryption.py +21 -2
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +1 -1
- flwr/common/secure_aggregation/quantization.py +1 -1
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/secure_aggregation/secaggplus_utils.py +1 -1
- flwr/common/serde.py +209 -3
- flwr/common/telemetry.py +25 -0
- flwr/common/typing.py +38 -0
- flwr/common/version.py +14 -0
- flwr/proto/clientappio_pb2.py +41 -0
- flwr/proto/clientappio_pb2.pyi +110 -0
- flwr/proto/clientappio_pb2_grpc.py +101 -0
- flwr/proto/clientappio_pb2_grpc.pyi +40 -0
- flwr/proto/common_pb2.py +36 -0
- flwr/proto/common_pb2.pyi +121 -0
- flwr/proto/common_pb2_grpc.py +4 -0
- flwr/proto/common_pb2_grpc.pyi +4 -0
- flwr/proto/driver_pb2.py +26 -19
- flwr/proto/driver_pb2.pyi +34 -0
- flwr/proto/driver_pb2_grpc.py +70 -0
- flwr/proto/driver_pb2_grpc.pyi +28 -0
- flwr/proto/exec_pb2.py +43 -0
- flwr/proto/exec_pb2.pyi +95 -0
- flwr/proto/exec_pb2_grpc.py +101 -0
- flwr/proto/exec_pb2_grpc.pyi +41 -0
- flwr/proto/fab_pb2.py +30 -0
- flwr/proto/fab_pb2.pyi +56 -0
- flwr/proto/fab_pb2_grpc.py +4 -0
- flwr/proto/fab_pb2_grpc.pyi +4 -0
- flwr/proto/fleet_pb2.py +29 -23
- flwr/proto/fleet_pb2.pyi +33 -0
- flwr/proto/fleet_pb2_grpc.py +102 -0
- flwr/proto/fleet_pb2_grpc.pyi +35 -0
- flwr/proto/grpcadapter_pb2.py +32 -0
- flwr/proto/grpcadapter_pb2.pyi +43 -0
- flwr/proto/grpcadapter_pb2_grpc.py +66 -0
- flwr/proto/grpcadapter_pb2_grpc.pyi +24 -0
- flwr/proto/message_pb2.py +41 -0
- flwr/proto/message_pb2.pyi +122 -0
- flwr/proto/message_pb2_grpc.py +4 -0
- flwr/proto/message_pb2_grpc.pyi +4 -0
- flwr/proto/run_pb2.py +35 -0
- flwr/proto/run_pb2.pyi +76 -0
- flwr/proto/run_pb2_grpc.py +4 -0
- flwr/proto/run_pb2_grpc.pyi +4 -0
- flwr/proto/task_pb2.py +7 -8
- flwr/proto/task_pb2.pyi +8 -5
- flwr/server/__init__.py +4 -8
- flwr/server/app.py +298 -350
- flwr/server/compat/app.py +6 -57
- flwr/server/compat/app_utils.py +5 -4
- flwr/server/compat/driver_client_proxy.py +29 -48
- flwr/server/compat/legacy_context.py +5 -4
- flwr/server/driver/__init__.py +2 -0
- flwr/server/driver/driver.py +22 -132
- flwr/server/driver/grpc_driver.py +224 -74
- flwr/server/driver/inmemory_driver.py +183 -0
- flwr/server/history.py +20 -20
- flwr/server/run_serverapp.py +121 -34
- flwr/server/server.py +11 -7
- flwr/server/server_app.py +59 -10
- flwr/server/serverapp_components.py +52 -0
- flwr/server/strategy/__init__.py +2 -2
- flwr/server/strategy/bulyan.py +1 -1
- flwr/server/strategy/dp_adaptive_clipping.py +3 -3
- flwr/server/strategy/dp_fixed_clipping.py +4 -3
- flwr/server/strategy/dpfedavg_adaptive.py +1 -1
- flwr/server/strategy/dpfedavg_fixed.py +1 -1
- flwr/server/strategy/fedadagrad.py +1 -1
- flwr/server/strategy/fedadam.py +1 -1
- flwr/server/strategy/fedavg_android.py +1 -1
- flwr/server/strategy/fedavgm.py +1 -1
- flwr/server/strategy/fedmedian.py +1 -1
- flwr/server/strategy/fedopt.py +1 -1
- flwr/server/strategy/fedprox.py +1 -1
- flwr/server/strategy/fedxgb_bagging.py +1 -1
- flwr/server/strategy/fedxgb_cyclic.py +1 -1
- flwr/server/strategy/fedxgb_nn_avg.py +1 -1
- flwr/server/strategy/fedyogi.py +1 -1
- flwr/server/strategy/krum.py +1 -1
- flwr/server/strategy/qfedavg.py +1 -1
- flwr/server/superlink/driver/__init__.py +1 -1
- flwr/server/superlink/driver/driver_grpc.py +1 -1
- flwr/server/superlink/driver/driver_servicer.py +51 -4
- flwr/server/superlink/ffs/__init__.py +24 -0
- flwr/server/superlink/ffs/disk_ffs.py +104 -0
- flwr/server/superlink/ffs/ffs.py +79 -0
- flwr/server/superlink/fleet/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_adapter/__init__.py +15 -0
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +131 -0
- flwr/server/superlink/fleet/grpc_bidi/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +1 -1
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +8 -2
- flwr/server/superlink/fleet/grpc_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +30 -2
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +214 -0
- flwr/server/superlink/fleet/message_handler/__init__.py +1 -1
- flwr/server/superlink/fleet/message_handler/message_handler.py +42 -2
- flwr/server/superlink/fleet/rest_rere/__init__.py +1 -1
- flwr/server/superlink/fleet/rest_rere/rest_api.py +59 -1
- flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
- flwr/server/superlink/fleet/vce/backend/backend.py +5 -5
- flwr/server/superlink/fleet/vce/backend/raybackend.py +53 -56
- flwr/server/superlink/fleet/vce/vce_api.py +190 -127
- flwr/server/superlink/state/__init__.py +1 -1
- flwr/server/superlink/state/in_memory_state.py +159 -42
- flwr/server/superlink/state/sqlite_state.py +243 -39
- flwr/server/superlink/state/state.py +81 -6
- flwr/server/superlink/state/state_factory.py +11 -2
- flwr/server/superlink/state/utils.py +62 -0
- flwr/server/typing.py +2 -0
- flwr/server/utils/__init__.py +1 -1
- flwr/server/utils/tensorboard.py +1 -1
- flwr/server/utils/validator.py +23 -9
- flwr/server/workflow/default_workflows.py +67 -25
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -6
- flwr/simulation/__init__.py +7 -4
- flwr/simulation/app.py +67 -36
- flwr/simulation/ray_transport/__init__.py +1 -1
- flwr/simulation/ray_transport/ray_actor.py +20 -46
- flwr/simulation/ray_transport/ray_client_proxy.py +36 -16
- flwr/simulation/run_simulation.py +308 -92
- flwr/superexec/__init__.py +21 -0
- flwr/superexec/app.py +184 -0
- flwr/superexec/deployment.py +185 -0
- flwr/superexec/exec_grpc.py +55 -0
- flwr/superexec/exec_servicer.py +70 -0
- flwr/superexec/executor.py +75 -0
- flwr/superexec/simulation.py +193 -0
- {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/METADATA +10 -6
- flwr_nightly-1.11.0.dev20240813.dist-info/RECORD +288 -0
- flwr_nightly-1.11.0.dev20240813.dist-info/entry_points.txt +10 -0
- flwr/cli/flower_toml.py +0 -140
- flwr/cli/new/templates/app/flower.toml.tpl +0 -13
- flwr/cli/new/templates/app/requirements.numpy.txt.tpl +0 -2
- flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +0 -4
- flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl +0 -4
- flwr_nightly-1.8.0.dev20240314.dist-info/RECORD +0 -211
- flwr_nightly-1.8.0.dev20240314.dist-info/entry_points.txt +0 -9
- {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.8.0.dev20240314.dist-info → flwr_nightly-1.11.0.dev20240813.dist-info}/WHEEL +0 -0
flwr/server/app.py
CHANGED
|
@@ -15,30 +15,41 @@
|
|
|
15
15
|
"""Flower server app."""
|
|
16
16
|
|
|
17
17
|
import argparse
|
|
18
|
-
import
|
|
18
|
+
import csv
|
|
19
19
|
import importlib.util
|
|
20
20
|
import sys
|
|
21
21
|
import threading
|
|
22
|
-
from logging import
|
|
22
|
+
from logging import INFO, WARN
|
|
23
23
|
from os.path import isfile
|
|
24
24
|
from pathlib import Path
|
|
25
|
-
from typing import
|
|
25
|
+
from typing import Optional, Sequence, Set, Tuple
|
|
26
26
|
|
|
27
27
|
import grpc
|
|
28
|
+
from cryptography.exceptions import UnsupportedAlgorithm
|
|
29
|
+
from cryptography.hazmat.primitives.asymmetric import ec
|
|
30
|
+
from cryptography.hazmat.primitives.serialization import (
|
|
31
|
+
load_ssh_private_key,
|
|
32
|
+
load_ssh_public_key,
|
|
33
|
+
)
|
|
28
34
|
|
|
29
35
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
|
|
30
36
|
from flwr.common.address import parse_address
|
|
31
37
|
from flwr.common.constant import (
|
|
32
38
|
MISSING_EXTRA_REST,
|
|
39
|
+
TRANSPORT_TYPE_GRPC_ADAPTER,
|
|
33
40
|
TRANSPORT_TYPE_GRPC_RERE,
|
|
34
41
|
TRANSPORT_TYPE_REST,
|
|
35
|
-
TRANSPORT_TYPE_VCE,
|
|
36
42
|
)
|
|
37
43
|
from flwr.common.exit_handlers import register_exit_handlers
|
|
38
44
|
from flwr.common.logger import log
|
|
45
|
+
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
|
46
|
+
private_key_to_bytes,
|
|
47
|
+
public_key_to_bytes,
|
|
48
|
+
)
|
|
39
49
|
from flwr.proto.fleet_pb2_grpc import ( # pylint: disable=E0611
|
|
40
50
|
add_FleetServicer_to_server,
|
|
41
51
|
)
|
|
52
|
+
from flwr.proto.grpcadapter_pb2_grpc import add_GrpcAdapterServicer_to_server
|
|
42
53
|
|
|
43
54
|
from .client_manager import ClientManager
|
|
44
55
|
from .history import History
|
|
@@ -46,12 +57,13 @@ from .server import Server, init_defaults, run_fl
|
|
|
46
57
|
from .server_config import ServerConfig
|
|
47
58
|
from .strategy import Strategy
|
|
48
59
|
from .superlink.driver.driver_grpc import run_driver_api_grpc
|
|
60
|
+
from .superlink.fleet.grpc_adapter.grpc_adapter_servicer import GrpcAdapterServicer
|
|
49
61
|
from .superlink.fleet.grpc_bidi.grpc_server import (
|
|
50
62
|
generic_create_grpc_server,
|
|
51
63
|
start_grpc_server,
|
|
52
64
|
)
|
|
53
65
|
from .superlink.fleet.grpc_rere.fleet_servicer import FleetServicer
|
|
54
|
-
from .superlink.fleet.
|
|
66
|
+
from .superlink.fleet.grpc_rere.server_interceptor import AuthenticateServerInterceptor
|
|
55
67
|
from .superlink.state import StateFactory
|
|
56
68
|
|
|
57
69
|
ADDRESS_DRIVER_API = "0.0.0.0:9091"
|
|
@@ -181,127 +193,17 @@ def start_server( # pylint: disable=too-many-arguments,too-many-locals
|
|
|
181
193
|
return hist
|
|
182
194
|
|
|
183
195
|
|
|
184
|
-
def run_driver_api() -> None:
|
|
185
|
-
"""Run Flower server (Driver API)."""
|
|
186
|
-
log(INFO, "Starting Flower server (Driver API)")
|
|
187
|
-
event(EventType.RUN_DRIVER_API_ENTER)
|
|
188
|
-
args = _parse_args_run_driver_api().parse_args()
|
|
189
|
-
|
|
190
|
-
# Parse IP address
|
|
191
|
-
parsed_address = parse_address(args.driver_api_address)
|
|
192
|
-
if not parsed_address:
|
|
193
|
-
sys.exit(f"Driver IP address ({args.driver_api_address}) cannot be parsed.")
|
|
194
|
-
host, port, is_v6 = parsed_address
|
|
195
|
-
address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}"
|
|
196
|
-
|
|
197
|
-
# Obtain certificates
|
|
198
|
-
certificates = _try_obtain_certificates(args)
|
|
199
|
-
|
|
200
|
-
# Initialize StateFactory
|
|
201
|
-
state_factory = StateFactory(args.database)
|
|
202
|
-
|
|
203
|
-
# Start server
|
|
204
|
-
grpc_server: grpc.Server = run_driver_api_grpc(
|
|
205
|
-
address=address,
|
|
206
|
-
state_factory=state_factory,
|
|
207
|
-
certificates=certificates,
|
|
208
|
-
)
|
|
209
|
-
|
|
210
|
-
# Graceful shutdown
|
|
211
|
-
register_exit_handlers(
|
|
212
|
-
event_type=EventType.RUN_DRIVER_API_LEAVE,
|
|
213
|
-
grpc_servers=[grpc_server],
|
|
214
|
-
bckg_threads=[],
|
|
215
|
-
)
|
|
216
|
-
|
|
217
|
-
# Block
|
|
218
|
-
grpc_server.wait_for_termination()
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
def run_fleet_api() -> None:
|
|
222
|
-
"""Run Flower server (Fleet API)."""
|
|
223
|
-
log(INFO, "Starting Flower server (Fleet API)")
|
|
224
|
-
event(EventType.RUN_FLEET_API_ENTER)
|
|
225
|
-
args = _parse_args_run_fleet_api().parse_args()
|
|
226
|
-
|
|
227
|
-
# Obtain certificates
|
|
228
|
-
certificates = _try_obtain_certificates(args)
|
|
229
|
-
|
|
230
|
-
# Initialize StateFactory
|
|
231
|
-
state_factory = StateFactory(args.database)
|
|
232
|
-
|
|
233
|
-
grpc_servers = []
|
|
234
|
-
bckg_threads = []
|
|
235
|
-
|
|
236
|
-
# Start Fleet API
|
|
237
|
-
if args.fleet_api_type == TRANSPORT_TYPE_REST:
|
|
238
|
-
if (
|
|
239
|
-
importlib.util.find_spec("requests")
|
|
240
|
-
and importlib.util.find_spec("starlette")
|
|
241
|
-
and importlib.util.find_spec("uvicorn")
|
|
242
|
-
) is None:
|
|
243
|
-
sys.exit(MISSING_EXTRA_REST)
|
|
244
|
-
address_arg = args.rest_fleet_api_address
|
|
245
|
-
parsed_address = parse_address(address_arg)
|
|
246
|
-
if not parsed_address:
|
|
247
|
-
sys.exit(f"Fleet IP address ({address_arg}) cannot be parsed.")
|
|
248
|
-
host, port, _ = parsed_address
|
|
249
|
-
fleet_thread = threading.Thread(
|
|
250
|
-
target=_run_fleet_api_rest,
|
|
251
|
-
args=(
|
|
252
|
-
host,
|
|
253
|
-
port,
|
|
254
|
-
args.ssl_keyfile,
|
|
255
|
-
args.ssl_certfile,
|
|
256
|
-
state_factory,
|
|
257
|
-
args.rest_fleet_api_workers,
|
|
258
|
-
),
|
|
259
|
-
)
|
|
260
|
-
fleet_thread.start()
|
|
261
|
-
bckg_threads.append(fleet_thread)
|
|
262
|
-
elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE:
|
|
263
|
-
address_arg = args.grpc_rere_fleet_api_address
|
|
264
|
-
parsed_address = parse_address(address_arg)
|
|
265
|
-
if not parsed_address:
|
|
266
|
-
sys.exit(f"Fleet IP address ({address_arg}) cannot be parsed.")
|
|
267
|
-
host, port, is_v6 = parsed_address
|
|
268
|
-
address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}"
|
|
269
|
-
fleet_server = _run_fleet_api_grpc_rere(
|
|
270
|
-
address=address,
|
|
271
|
-
state_factory=state_factory,
|
|
272
|
-
certificates=certificates,
|
|
273
|
-
)
|
|
274
|
-
grpc_servers.append(fleet_server)
|
|
275
|
-
else:
|
|
276
|
-
raise ValueError(f"Unknown fleet_api_type: {args.fleet_api_type}")
|
|
277
|
-
|
|
278
|
-
# Graceful shutdown
|
|
279
|
-
register_exit_handlers(
|
|
280
|
-
event_type=EventType.RUN_FLEET_API_LEAVE,
|
|
281
|
-
grpc_servers=grpc_servers,
|
|
282
|
-
bckg_threads=bckg_threads,
|
|
283
|
-
)
|
|
284
|
-
|
|
285
|
-
# Block
|
|
286
|
-
if len(grpc_servers) > 0:
|
|
287
|
-
grpc_servers[0].wait_for_termination()
|
|
288
|
-
elif len(bckg_threads) > 0:
|
|
289
|
-
bckg_threads[0].join()
|
|
290
|
-
|
|
291
|
-
|
|
292
196
|
# pylint: disable=too-many-branches, too-many-locals, too-many-statements
|
|
293
197
|
def run_superlink() -> None:
|
|
294
|
-
"""Run Flower
|
|
295
|
-
log(INFO, "Starting Flower
|
|
198
|
+
"""Run Flower SuperLink (Driver API and Fleet API)."""
|
|
199
|
+
log(INFO, "Starting Flower SuperLink")
|
|
200
|
+
|
|
296
201
|
event(EventType.RUN_SUPERLINK_ENTER)
|
|
202
|
+
|
|
297
203
|
args = _parse_args_run_superlink().parse_args()
|
|
298
204
|
|
|
299
205
|
# Parse IP address
|
|
300
|
-
|
|
301
|
-
if not parsed_address:
|
|
302
|
-
sys.exit(f"Driver IP address ({args.driver_api_address}) cannot be parsed.")
|
|
303
|
-
host, port, is_v6 = parsed_address
|
|
304
|
-
address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}"
|
|
206
|
+
driver_address, _, _ = _format_address(args.driver_api_address)
|
|
305
207
|
|
|
306
208
|
# Obtain certificates
|
|
307
209
|
certificates = _try_obtain_certificates(args)
|
|
@@ -311,13 +213,35 @@ def run_superlink() -> None:
|
|
|
311
213
|
|
|
312
214
|
# Start Driver API
|
|
313
215
|
driver_server: grpc.Server = run_driver_api_grpc(
|
|
314
|
-
address=
|
|
216
|
+
address=driver_address,
|
|
315
217
|
state_factory=state_factory,
|
|
316
218
|
certificates=certificates,
|
|
317
219
|
)
|
|
318
220
|
|
|
319
221
|
grpc_servers = [driver_server]
|
|
320
222
|
bckg_threads = []
|
|
223
|
+
if not args.fleet_api_address:
|
|
224
|
+
if args.fleet_api_type in [
|
|
225
|
+
TRANSPORT_TYPE_GRPC_RERE,
|
|
226
|
+
TRANSPORT_TYPE_GRPC_ADAPTER,
|
|
227
|
+
]:
|
|
228
|
+
args.fleet_api_address = ADDRESS_FLEET_API_GRPC_RERE
|
|
229
|
+
elif args.fleet_api_type == TRANSPORT_TYPE_REST:
|
|
230
|
+
args.fleet_api_address = ADDRESS_FLEET_API_REST
|
|
231
|
+
|
|
232
|
+
fleet_address, host, port = _format_address(args.fleet_api_address)
|
|
233
|
+
|
|
234
|
+
num_workers = args.fleet_api_num_workers
|
|
235
|
+
if num_workers != 1:
|
|
236
|
+
log(
|
|
237
|
+
WARN,
|
|
238
|
+
"The Fleet API currently supports only 1 worker. "
|
|
239
|
+
"You have specified %d workers. "
|
|
240
|
+
"Support for multiple workers will be added in future releases. "
|
|
241
|
+
"Proceeding with a single worker.",
|
|
242
|
+
args.fleet_api_num_workers,
|
|
243
|
+
)
|
|
244
|
+
num_workers = 1
|
|
321
245
|
|
|
322
246
|
# Start Fleet API
|
|
323
247
|
if args.fleet_api_type == TRANSPORT_TYPE_REST:
|
|
@@ -327,48 +251,60 @@ def run_superlink() -> None:
|
|
|
327
251
|
and importlib.util.find_spec("uvicorn")
|
|
328
252
|
) is None:
|
|
329
253
|
sys.exit(MISSING_EXTRA_REST)
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
254
|
+
|
|
255
|
+
_, ssl_certfile, ssl_keyfile = (
|
|
256
|
+
certificates if certificates is not None else (None, None, None)
|
|
257
|
+
)
|
|
258
|
+
|
|
335
259
|
fleet_thread = threading.Thread(
|
|
336
260
|
target=_run_fleet_api_rest,
|
|
337
261
|
args=(
|
|
338
262
|
host,
|
|
339
263
|
port,
|
|
340
|
-
|
|
341
|
-
|
|
264
|
+
ssl_keyfile,
|
|
265
|
+
ssl_certfile,
|
|
342
266
|
state_factory,
|
|
343
|
-
|
|
267
|
+
num_workers,
|
|
344
268
|
),
|
|
345
269
|
)
|
|
346
270
|
fleet_thread.start()
|
|
347
271
|
bckg_threads.append(fleet_thread)
|
|
348
272
|
elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE:
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
if not
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
273
|
+
maybe_keys = _try_setup_client_authentication(args, certificates)
|
|
274
|
+
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None
|
|
275
|
+
if maybe_keys is not None:
|
|
276
|
+
(
|
|
277
|
+
client_public_keys,
|
|
278
|
+
server_private_key,
|
|
279
|
+
server_public_key,
|
|
280
|
+
) = maybe_keys
|
|
281
|
+
state = state_factory.state()
|
|
282
|
+
state.store_client_public_keys(client_public_keys)
|
|
283
|
+
state.store_server_private_public_key(
|
|
284
|
+
private_key_to_bytes(server_private_key),
|
|
285
|
+
public_key_to_bytes(server_public_key),
|
|
286
|
+
)
|
|
287
|
+
log(
|
|
288
|
+
INFO,
|
|
289
|
+
"Client authentication enabled with %d known public keys",
|
|
290
|
+
len(client_public_keys),
|
|
291
|
+
)
|
|
292
|
+
interceptors = [AuthenticateServerInterceptor(state)]
|
|
293
|
+
|
|
355
294
|
fleet_server = _run_fleet_api_grpc_rere(
|
|
356
|
-
address=
|
|
295
|
+
address=fleet_address,
|
|
357
296
|
state_factory=state_factory,
|
|
358
297
|
certificates=certificates,
|
|
298
|
+
interceptors=interceptors,
|
|
359
299
|
)
|
|
360
300
|
grpc_servers.append(fleet_server)
|
|
361
|
-
elif args.fleet_api_type ==
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
num_supernodes=args.num_supernodes,
|
|
365
|
-
client_app_attr=args.client_app,
|
|
366
|
-
backend_name=args.backend,
|
|
367
|
-
backend_config_json_stream=args.backend_config,
|
|
368
|
-
app_dir=args.app_dir,
|
|
301
|
+
elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_ADAPTER:
|
|
302
|
+
fleet_server = _run_fleet_api_grpc_adapter(
|
|
303
|
+
address=fleet_address,
|
|
369
304
|
state_factory=state_factory,
|
|
370
|
-
|
|
305
|
+
certificates=certificates,
|
|
371
306
|
)
|
|
307
|
+
grpc_servers.append(fleet_server)
|
|
372
308
|
else:
|
|
373
309
|
raise ValueError(f"Unknown fleet_api_type: {args.fleet_api_type}")
|
|
374
310
|
|
|
@@ -388,33 +324,164 @@ def run_superlink() -> None:
|
|
|
388
324
|
driver_server.wait_for_termination(timeout=1)
|
|
389
325
|
|
|
390
326
|
|
|
327
|
+
def _format_address(address: str) -> Tuple[str, str, int]:
|
|
328
|
+
parsed_address = parse_address(address)
|
|
329
|
+
if not parsed_address:
|
|
330
|
+
sys.exit(
|
|
331
|
+
f"Address ({address}) cannot be parsed (expected: URL or IPv4 or IPv6)."
|
|
332
|
+
)
|
|
333
|
+
host, port, is_v6 = parsed_address
|
|
334
|
+
return (f"[{host}]:{port}" if is_v6 else f"{host}:{port}", host, port)
|
|
335
|
+
|
|
336
|
+
|
|
337
|
+
def _try_setup_client_authentication(
|
|
338
|
+
args: argparse.Namespace,
|
|
339
|
+
certificates: Optional[Tuple[bytes, bytes, bytes]],
|
|
340
|
+
) -> Optional[Tuple[Set[bytes], ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]]:
|
|
341
|
+
if (
|
|
342
|
+
not args.auth_list_public_keys
|
|
343
|
+
and not args.auth_superlink_private_key
|
|
344
|
+
and not args.auth_superlink_public_key
|
|
345
|
+
):
|
|
346
|
+
return None
|
|
347
|
+
|
|
348
|
+
if (
|
|
349
|
+
not args.auth_list_public_keys
|
|
350
|
+
or not args.auth_superlink_private_key
|
|
351
|
+
or not args.auth_superlink_public_key
|
|
352
|
+
):
|
|
353
|
+
sys.exit(
|
|
354
|
+
"Authentication requires providing file paths for "
|
|
355
|
+
"'--auth-list-public-keys', '--auth-superlink-private-key' and "
|
|
356
|
+
"'--auth-superlink-public-key'. Provide all three to enable authentication."
|
|
357
|
+
)
|
|
358
|
+
|
|
359
|
+
if certificates is None:
|
|
360
|
+
sys.exit(
|
|
361
|
+
"Authentication requires secure connections. "
|
|
362
|
+
"Please provide certificate paths to `--ssl-certfile`, "
|
|
363
|
+
"`--ssl-keyfile`, and `—-ssl-ca-certfile` and try again."
|
|
364
|
+
)
|
|
365
|
+
|
|
366
|
+
client_keys_file_path = Path(args.auth_list_public_keys)
|
|
367
|
+
if not client_keys_file_path.exists():
|
|
368
|
+
sys.exit(
|
|
369
|
+
"The provided path to the known public keys CSV file does not exist: "
|
|
370
|
+
f"{client_keys_file_path}. "
|
|
371
|
+
"Please provide the CSV file path containing known public keys "
|
|
372
|
+
"to '--auth-list-public-keys'."
|
|
373
|
+
)
|
|
374
|
+
|
|
375
|
+
client_public_keys: Set[bytes] = set()
|
|
376
|
+
|
|
377
|
+
try:
|
|
378
|
+
ssh_private_key = load_ssh_private_key(
|
|
379
|
+
Path(args.auth_superlink_private_key).read_bytes(),
|
|
380
|
+
None,
|
|
381
|
+
)
|
|
382
|
+
if not isinstance(ssh_private_key, ec.EllipticCurvePrivateKey):
|
|
383
|
+
raise ValueError()
|
|
384
|
+
except (ValueError, UnsupportedAlgorithm):
|
|
385
|
+
sys.exit(
|
|
386
|
+
"Error: Unable to parse the private key file in "
|
|
387
|
+
"'--auth-superlink-private-key'. Authentication requires elliptic "
|
|
388
|
+
"curve private and public key pair. Please ensure that the file "
|
|
389
|
+
"path points to a valid private key file and try again."
|
|
390
|
+
)
|
|
391
|
+
|
|
392
|
+
try:
|
|
393
|
+
ssh_public_key = load_ssh_public_key(
|
|
394
|
+
Path(args.auth_superlink_public_key).read_bytes()
|
|
395
|
+
)
|
|
396
|
+
if not isinstance(ssh_public_key, ec.EllipticCurvePublicKey):
|
|
397
|
+
raise ValueError()
|
|
398
|
+
except (ValueError, UnsupportedAlgorithm):
|
|
399
|
+
sys.exit(
|
|
400
|
+
"Error: Unable to parse the public key file in "
|
|
401
|
+
"'--auth-superlink-public-key'. Authentication requires elliptic "
|
|
402
|
+
"curve private and public key pair. Please ensure that the file "
|
|
403
|
+
"path points to a valid public key file and try again."
|
|
404
|
+
)
|
|
405
|
+
|
|
406
|
+
with open(client_keys_file_path, newline="", encoding="utf-8") as csvfile:
|
|
407
|
+
reader = csv.reader(csvfile)
|
|
408
|
+
for row in reader:
|
|
409
|
+
for element in row:
|
|
410
|
+
public_key = load_ssh_public_key(element.encode())
|
|
411
|
+
if isinstance(public_key, ec.EllipticCurvePublicKey):
|
|
412
|
+
client_public_keys.add(public_key_to_bytes(public_key))
|
|
413
|
+
else:
|
|
414
|
+
sys.exit(
|
|
415
|
+
"Error: Unable to parse the public keys in the CSV "
|
|
416
|
+
"file. Please ensure that the CSV file path points to a valid "
|
|
417
|
+
"known SSH public keys files and try again."
|
|
418
|
+
)
|
|
419
|
+
return (
|
|
420
|
+
client_public_keys,
|
|
421
|
+
ssh_private_key,
|
|
422
|
+
ssh_public_key,
|
|
423
|
+
)
|
|
424
|
+
|
|
425
|
+
|
|
391
426
|
def _try_obtain_certificates(
|
|
392
427
|
args: argparse.Namespace,
|
|
393
428
|
) -> Optional[Tuple[bytes, bytes, bytes]]:
|
|
394
429
|
# Obtain certificates
|
|
395
430
|
if args.insecure:
|
|
396
431
|
log(WARN, "Option `--insecure` was set. Starting insecure HTTP server.")
|
|
397
|
-
|
|
432
|
+
return None
|
|
398
433
|
# Check if certificates are provided
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
434
|
+
if args.fleet_api_type in [TRANSPORT_TYPE_GRPC_RERE, TRANSPORT_TYPE_GRPC_ADAPTER]:
|
|
435
|
+
if args.ssl_certfile and args.ssl_keyfile and args.ssl_ca_certfile:
|
|
436
|
+
if not isfile(args.ssl_ca_certfile):
|
|
437
|
+
sys.exit("Path argument `--ssl-ca-certfile` does not point to a file.")
|
|
438
|
+
if not isfile(args.ssl_certfile):
|
|
439
|
+
sys.exit("Path argument `--ssl-certfile` does not point to a file.")
|
|
440
|
+
if not isfile(args.ssl_keyfile):
|
|
441
|
+
sys.exit("Path argument `--ssl-keyfile` does not point to a file.")
|
|
442
|
+
certificates = (
|
|
443
|
+
Path(args.ssl_ca_certfile).read_bytes(), # CA certificate
|
|
444
|
+
Path(args.ssl_certfile).read_bytes(), # server certificate
|
|
445
|
+
Path(args.ssl_keyfile).read_bytes(), # server private key
|
|
446
|
+
)
|
|
447
|
+
return certificates
|
|
448
|
+
if args.ssl_certfile or args.ssl_keyfile or args.ssl_ca_certfile:
|
|
449
|
+
sys.exit(
|
|
450
|
+
"You need to provide valid file paths to `--ssl-certfile`, "
|
|
451
|
+
"`--ssl-keyfile`, and `—-ssl-ca-certfile` to create a secure "
|
|
452
|
+
"connection in Fleet API server (gRPC-rere)."
|
|
453
|
+
)
|
|
454
|
+
if args.fleet_api_type == TRANSPORT_TYPE_REST:
|
|
455
|
+
if args.ssl_certfile and args.ssl_keyfile:
|
|
456
|
+
if not isfile(args.ssl_certfile):
|
|
457
|
+
sys.exit("Path argument `--ssl-certfile` does not point to a file.")
|
|
458
|
+
if not isfile(args.ssl_keyfile):
|
|
459
|
+
sys.exit("Path argument `--ssl-keyfile` does not point to a file.")
|
|
460
|
+
certificates = (
|
|
461
|
+
b"",
|
|
462
|
+
Path(args.ssl_certfile).read_bytes(), # server certificate
|
|
463
|
+
Path(args.ssl_keyfile).read_bytes(), # server private key
|
|
464
|
+
)
|
|
465
|
+
return certificates
|
|
466
|
+
if args.ssl_certfile or args.ssl_keyfile:
|
|
467
|
+
sys.exit(
|
|
468
|
+
"You need to provide valid file paths to `--ssl-certfile` "
|
|
469
|
+
"and `--ssl-keyfile` to create a secure connection "
|
|
470
|
+
"in Fleet API server (REST, experimental)."
|
|
471
|
+
)
|
|
472
|
+
sys.exit(
|
|
473
|
+
"Certificates are required unless running in insecure mode. "
|
|
474
|
+
"Please provide certificate paths to `--ssl-certfile`, "
|
|
475
|
+
"`--ssl-keyfile`, and `—-ssl-ca-certfile` or run the server "
|
|
476
|
+
"in insecure mode using '--insecure' if you understand the risks."
|
|
477
|
+
)
|
|
412
478
|
|
|
413
479
|
|
|
414
480
|
def _run_fleet_api_grpc_rere(
|
|
415
481
|
address: str,
|
|
416
482
|
state_factory: StateFactory,
|
|
417
483
|
certificates: Optional[Tuple[bytes, bytes, bytes]],
|
|
484
|
+
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None,
|
|
418
485
|
) -> grpc.Server:
|
|
419
486
|
"""Run Fleet API (gRPC, request-response)."""
|
|
420
487
|
# Create Fleet API gRPC server
|
|
@@ -427,6 +494,7 @@ def _run_fleet_api_grpc_rere(
|
|
|
427
494
|
server_address=address,
|
|
428
495
|
max_message_length=GRPC_MAX_MESSAGE_LENGTH,
|
|
429
496
|
certificates=certificates,
|
|
497
|
+
interceptors=interceptors,
|
|
430
498
|
)
|
|
431
499
|
|
|
432
500
|
log(INFO, "Flower ECE: Starting Fleet API (gRPC-rere) on %s", address)
|
|
@@ -435,28 +503,29 @@ def _run_fleet_api_grpc_rere(
|
|
|
435
503
|
return fleet_grpc_server
|
|
436
504
|
|
|
437
505
|
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
num_supernodes: int,
|
|
441
|
-
client_app_attr: str,
|
|
442
|
-
backend_name: str,
|
|
443
|
-
backend_config_json_stream: str,
|
|
444
|
-
app_dir: str,
|
|
506
|
+
def _run_fleet_api_grpc_adapter(
|
|
507
|
+
address: str,
|
|
445
508
|
state_factory: StateFactory,
|
|
446
|
-
|
|
447
|
-
) ->
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
num_supernodes=num_supernodes,
|
|
452
|
-
client_app_attr=client_app_attr,
|
|
453
|
-
backend_name=backend_name,
|
|
454
|
-
backend_config_json_stream=backend_config_json_stream,
|
|
509
|
+
certificates: Optional[Tuple[bytes, bytes, bytes]],
|
|
510
|
+
) -> grpc.Server:
|
|
511
|
+
"""Run Fleet API (GrpcAdapter)."""
|
|
512
|
+
# Create Fleet API gRPC server
|
|
513
|
+
fleet_servicer = GrpcAdapterServicer(
|
|
455
514
|
state_factory=state_factory,
|
|
456
|
-
|
|
457
|
-
|
|
515
|
+
)
|
|
516
|
+
fleet_add_servicer_to_server_fn = add_GrpcAdapterServicer_to_server
|
|
517
|
+
fleet_grpc_server = generic_create_grpc_server(
|
|
518
|
+
servicer_and_add_fn=(fleet_servicer, fleet_add_servicer_to_server_fn),
|
|
519
|
+
server_address=address,
|
|
520
|
+
max_message_length=GRPC_MAX_MESSAGE_LENGTH,
|
|
521
|
+
certificates=certificates,
|
|
458
522
|
)
|
|
459
523
|
|
|
524
|
+
log(INFO, "Flower ECE: Starting Fleet API (GrpcAdapter) on %s", address)
|
|
525
|
+
fleet_grpc_server.start()
|
|
526
|
+
|
|
527
|
+
return fleet_grpc_server
|
|
528
|
+
|
|
460
529
|
|
|
461
530
|
# pylint: disable=import-outside-toplevel,too-many-arguments
|
|
462
531
|
def _run_fleet_api_rest(
|
|
@@ -465,7 +534,7 @@ def _run_fleet_api_rest(
|
|
|
465
534
|
ssl_keyfile: Optional[str],
|
|
466
535
|
ssl_certfile: Optional[str],
|
|
467
536
|
state_factory: StateFactory,
|
|
468
|
-
|
|
537
|
+
num_workers: int,
|
|
469
538
|
) -> None:
|
|
470
539
|
"""Run Driver API (REST-based)."""
|
|
471
540
|
try:
|
|
@@ -474,25 +543,12 @@ def _run_fleet_api_rest(
|
|
|
474
543
|
from flwr.server.superlink.fleet.rest_rere.rest_api import app as fast_api_app
|
|
475
544
|
except ModuleNotFoundError:
|
|
476
545
|
sys.exit(MISSING_EXTRA_REST)
|
|
477
|
-
|
|
478
|
-
raise ValueError(
|
|
479
|
-
f"The supported number of workers for the Fleet API (REST server) is "
|
|
480
|
-
f"1. Instead given {workers}. The functionality of >1 workers will be "
|
|
481
|
-
f"added in the future releases."
|
|
482
|
-
)
|
|
546
|
+
|
|
483
547
|
log(INFO, "Starting Flower REST server")
|
|
484
548
|
|
|
485
549
|
# See: https://www.starlette.io/applications/#accessing-the-app-instance
|
|
486
550
|
fast_api_app.state.STATE_FACTORY = state_factory
|
|
487
551
|
|
|
488
|
-
validation_exceptions = _validate_ssl_files(
|
|
489
|
-
ssl_certfile=ssl_certfile, ssl_keyfile=ssl_keyfile
|
|
490
|
-
)
|
|
491
|
-
if any(validation_exceptions):
|
|
492
|
-
# Starting with 3.11 we can use ExceptionGroup but for now
|
|
493
|
-
# this seems to be the reasonable approach.
|
|
494
|
-
raise ValueError(validation_exceptions)
|
|
495
|
-
|
|
496
552
|
uvicorn.run(
|
|
497
553
|
app="flwr.server.superlink.fleet.rest_rere.rest_api:app",
|
|
498
554
|
port=port,
|
|
@@ -501,76 +557,14 @@ def _run_fleet_api_rest(
|
|
|
501
557
|
access_log=True,
|
|
502
558
|
ssl_keyfile=ssl_keyfile,
|
|
503
559
|
ssl_certfile=ssl_certfile,
|
|
504
|
-
workers=
|
|
505
|
-
)
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
def _validate_ssl_files(
|
|
509
|
-
ssl_keyfile: Optional[str], ssl_certfile: Optional[str]
|
|
510
|
-
) -> List[ValueError]:
|
|
511
|
-
validation_exceptions = []
|
|
512
|
-
|
|
513
|
-
if ssl_keyfile is not None and not isfile(ssl_keyfile):
|
|
514
|
-
msg = "Path argument `--ssl-keyfile` does not point to a file."
|
|
515
|
-
log(ERROR, msg)
|
|
516
|
-
validation_exceptions.append(ValueError(msg))
|
|
517
|
-
|
|
518
|
-
if ssl_certfile is not None and not isfile(ssl_certfile):
|
|
519
|
-
msg = "Path argument `--ssl-certfile` does not point to a file."
|
|
520
|
-
log(ERROR, msg)
|
|
521
|
-
validation_exceptions.append(ValueError(msg))
|
|
522
|
-
|
|
523
|
-
if not bool(ssl_keyfile) == bool(ssl_certfile):
|
|
524
|
-
msg = (
|
|
525
|
-
"When setting one of `--ssl-keyfile` and "
|
|
526
|
-
"`--ssl-certfile`, both have to be used."
|
|
527
|
-
)
|
|
528
|
-
log(ERROR, msg)
|
|
529
|
-
validation_exceptions.append(ValueError(msg))
|
|
530
|
-
|
|
531
|
-
return validation_exceptions
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
def _parse_args_run_driver_api() -> argparse.ArgumentParser:
|
|
535
|
-
"""Parse command line arguments for Driver API."""
|
|
536
|
-
parser = argparse.ArgumentParser(
|
|
537
|
-
description="Start a Flower Driver API server. "
|
|
538
|
-
"This server will be responsible for "
|
|
539
|
-
"receiving TaskIns from the Driver script and "
|
|
540
|
-
"sending them to the Fleet API. Once the client nodes "
|
|
541
|
-
"are done, they will send the TaskRes back to this Driver API server (through"
|
|
542
|
-
" the Fleet API) which will then send them back to the Driver script.",
|
|
543
|
-
)
|
|
544
|
-
|
|
545
|
-
_add_args_common(parser=parser)
|
|
546
|
-
_add_args_driver_api(parser=parser)
|
|
547
|
-
|
|
548
|
-
return parser
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
def _parse_args_run_fleet_api() -> argparse.ArgumentParser:
|
|
552
|
-
"""Parse command line arguments for Fleet API."""
|
|
553
|
-
parser = argparse.ArgumentParser(
|
|
554
|
-
description="Start a Flower Fleet API server."
|
|
555
|
-
"This server will be responsible for "
|
|
556
|
-
"sending TaskIns (received from the Driver API) to the client nodes "
|
|
557
|
-
"and of receiving TaskRes sent back from those same client nodes once "
|
|
558
|
-
"they are done. Then, this Fleet API server can send those "
|
|
559
|
-
"TaskRes back to the Driver API.",
|
|
560
|
+
workers=num_workers,
|
|
560
561
|
)
|
|
561
562
|
|
|
562
|
-
_add_args_common(parser=parser)
|
|
563
|
-
_add_args_fleet_api(parser=parser)
|
|
564
|
-
|
|
565
|
-
return parser
|
|
566
|
-
|
|
567
563
|
|
|
568
564
|
def _parse_args_run_superlink() -> argparse.ArgumentParser:
|
|
569
565
|
"""Parse command line arguments for both Driver API and Fleet API."""
|
|
570
566
|
parser = argparse.ArgumentParser(
|
|
571
|
-
description="
|
|
572
|
-
"(meaning, a Driver API and a Fleet API), "
|
|
573
|
-
"that clients will be able to connect to.",
|
|
567
|
+
description="Start a Flower SuperLink",
|
|
574
568
|
)
|
|
575
569
|
|
|
576
570
|
_add_args_common(parser=parser)
|
|
@@ -589,13 +583,23 @@ def _add_args_common(parser: argparse.ArgumentParser) -> None:
|
|
|
589
583
|
"Use this flag only if you understand the risks.",
|
|
590
584
|
)
|
|
591
585
|
parser.add_argument(
|
|
592
|
-
"--
|
|
593
|
-
|
|
594
|
-
|
|
586
|
+
"--ssl-certfile",
|
|
587
|
+
help="Fleet API server SSL certificate file (as a path str) "
|
|
588
|
+
"to create a secure connection.",
|
|
589
|
+
type=str,
|
|
590
|
+
default=None,
|
|
591
|
+
)
|
|
592
|
+
parser.add_argument(
|
|
593
|
+
"--ssl-keyfile",
|
|
594
|
+
help="Fleet API server SSL private key file (as a path str) "
|
|
595
|
+
"to create a secure connection.",
|
|
596
|
+
type=str,
|
|
597
|
+
)
|
|
598
|
+
parser.add_argument(
|
|
599
|
+
"--ssl-ca-certfile",
|
|
600
|
+
help="Fleet API server SSL CA certificate file (as a path str) "
|
|
601
|
+
"to create a secure connection.",
|
|
595
602
|
type=str,
|
|
596
|
-
help="Paths to the CA certificate, server certificate, and server private "
|
|
597
|
-
"key, in that order. Note: The server can only be started without "
|
|
598
|
-
"certificates by enabling the `--insecure` flag.",
|
|
599
603
|
)
|
|
600
604
|
parser.add_argument(
|
|
601
605
|
"--database",
|
|
@@ -606,108 +610,52 @@ def _add_args_common(parser: argparse.ArgumentParser) -> None:
|
|
|
606
610
|
"Flower will just create a state in memory.",
|
|
607
611
|
default=DATABASE,
|
|
608
612
|
)
|
|
613
|
+
parser.add_argument(
|
|
614
|
+
"--auth-list-public-keys",
|
|
615
|
+
type=str,
|
|
616
|
+
help="A CSV file (as a path str) containing a list of known public "
|
|
617
|
+
"keys to enable authentication.",
|
|
618
|
+
)
|
|
619
|
+
parser.add_argument(
|
|
620
|
+
"--auth-superlink-private-key",
|
|
621
|
+
type=str,
|
|
622
|
+
help="The SuperLink's private key (as a path str) to enable authentication.",
|
|
623
|
+
)
|
|
624
|
+
parser.add_argument(
|
|
625
|
+
"--auth-superlink-public-key",
|
|
626
|
+
type=str,
|
|
627
|
+
help="The SuperLink's public key (as a path str) to enable authentication.",
|
|
628
|
+
)
|
|
609
629
|
|
|
610
630
|
|
|
611
631
|
def _add_args_driver_api(parser: argparse.ArgumentParser) -> None:
|
|
612
632
|
parser.add_argument(
|
|
613
633
|
"--driver-api-address",
|
|
614
|
-
help="Driver API (gRPC) server address (IPv4, IPv6, or a domain name)",
|
|
634
|
+
help="Driver API (gRPC) server address (IPv4, IPv6, or a domain name).",
|
|
615
635
|
default=ADDRESS_DRIVER_API,
|
|
616
636
|
)
|
|
617
637
|
|
|
618
638
|
|
|
619
639
|
def _add_args_fleet_api(parser: argparse.ArgumentParser) -> None:
|
|
620
640
|
# Fleet API transport layer type
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
"--grpc-rere",
|
|
624
|
-
action="store_const",
|
|
625
|
-
dest="fleet_api_type",
|
|
626
|
-
const=TRANSPORT_TYPE_GRPC_RERE,
|
|
641
|
+
parser.add_argument(
|
|
642
|
+
"--fleet-api-type",
|
|
627
643
|
default=TRANSPORT_TYPE_GRPC_RERE,
|
|
628
|
-
help="Start a Fleet API server (gRPC-rere)",
|
|
629
|
-
)
|
|
630
|
-
ex_group.add_argument(
|
|
631
|
-
"--rest",
|
|
632
|
-
action="store_const",
|
|
633
|
-
dest="fleet_api_type",
|
|
634
|
-
const=TRANSPORT_TYPE_REST,
|
|
635
|
-
help="Start a Fleet API server (REST, experimental)",
|
|
636
|
-
)
|
|
637
|
-
|
|
638
|
-
ex_group.add_argument(
|
|
639
|
-
"--vce",
|
|
640
|
-
action="store_const",
|
|
641
|
-
dest="fleet_api_type",
|
|
642
|
-
const=TRANSPORT_TYPE_VCE,
|
|
643
|
-
help="Start a Fleet API server (VirtualClientEngine)",
|
|
644
|
-
)
|
|
645
|
-
|
|
646
|
-
# Fleet API gRPC-rere options
|
|
647
|
-
grpc_rere_group = parser.add_argument_group(
|
|
648
|
-
"Fleet API (gRPC-rere) server options", ""
|
|
649
|
-
)
|
|
650
|
-
grpc_rere_group.add_argument(
|
|
651
|
-
"--grpc-rere-fleet-api-address",
|
|
652
|
-
help="Fleet API (gRPC-rere) server address (IPv4, IPv6, or a domain name)",
|
|
653
|
-
default=ADDRESS_FLEET_API_GRPC_RERE,
|
|
654
|
-
)
|
|
655
|
-
|
|
656
|
-
# Fleet API REST options
|
|
657
|
-
rest_group = parser.add_argument_group("Fleet API (REST) server options", "")
|
|
658
|
-
rest_group.add_argument(
|
|
659
|
-
"--rest-fleet-api-address",
|
|
660
|
-
help="Fleet API (REST) server address (IPv4, IPv6, or a domain name)",
|
|
661
|
-
default=ADDRESS_FLEET_API_REST,
|
|
662
|
-
)
|
|
663
|
-
rest_group.add_argument(
|
|
664
|
-
"--ssl-certfile",
|
|
665
|
-
help="Fleet API (REST) server SSL certificate file (as a path str), "
|
|
666
|
-
"needed for using 'https'.",
|
|
667
|
-
default=None,
|
|
668
|
-
)
|
|
669
|
-
rest_group.add_argument(
|
|
670
|
-
"--ssl-keyfile",
|
|
671
|
-
help="Fleet API (REST) server SSL private key file (as a path str), "
|
|
672
|
-
"needed for using 'https'.",
|
|
673
|
-
default=None,
|
|
674
|
-
)
|
|
675
|
-
rest_group.add_argument(
|
|
676
|
-
"--rest-fleet-api-workers",
|
|
677
|
-
help="Set the number of concurrent workers for the Fleet API REST server.",
|
|
678
|
-
type=int,
|
|
679
|
-
default=1,
|
|
680
|
-
)
|
|
681
|
-
|
|
682
|
-
# Fleet API VCE options
|
|
683
|
-
vce_group = parser.add_argument_group("Fleet API (VCE) server options", "")
|
|
684
|
-
vce_group.add_argument(
|
|
685
|
-
"--client-app",
|
|
686
|
-
help="For example: `client:app` or `project.package.module:wrapper.app`.",
|
|
687
|
-
)
|
|
688
|
-
vce_group.add_argument(
|
|
689
|
-
"--num-supernodes",
|
|
690
|
-
type=int,
|
|
691
|
-
help="Number of simulated SuperNodes.",
|
|
692
|
-
)
|
|
693
|
-
vce_group.add_argument(
|
|
694
|
-
"--backend",
|
|
695
|
-
default="ray",
|
|
696
644
|
type=str,
|
|
697
|
-
|
|
645
|
+
choices=[
|
|
646
|
+
TRANSPORT_TYPE_GRPC_RERE,
|
|
647
|
+
TRANSPORT_TYPE_GRPC_ADAPTER,
|
|
648
|
+
TRANSPORT_TYPE_REST,
|
|
649
|
+
],
|
|
650
|
+
help="Start a gRPC-rere or REST (experimental) Fleet API server.",
|
|
698
651
|
)
|
|
699
|
-
|
|
700
|
-
"--
|
|
701
|
-
|
|
702
|
-
default='{"client_resources": {"num_cpus":1, "num_gpus":0.0}, "tensorflow": 0}',
|
|
703
|
-
help='A JSON formatted stream, e.g \'{"<keyA>":<value>, "<keyB>":<value>}\' to '
|
|
704
|
-
"configure a backend. Values supported in <value> are those included by "
|
|
705
|
-
"`flwr.common.typing.ConfigsRecordValues`. ",
|
|
652
|
+
parser.add_argument(
|
|
653
|
+
"--fleet-api-address",
|
|
654
|
+
help="Fleet API server address (IPv4, IPv6, or a domain name).",
|
|
706
655
|
)
|
|
707
656
|
parser.add_argument(
|
|
708
|
-
"--
|
|
709
|
-
default=
|
|
710
|
-
|
|
711
|
-
"
|
|
712
|
-
" Default: current working directory.",
|
|
657
|
+
"--fleet-api-num-workers",
|
|
658
|
+
default=1,
|
|
659
|
+
type=int,
|
|
660
|
+
help="Set the number of concurrent workers for the Fleet API server.",
|
|
713
661
|
)
|