flwr-nightly 1.9.0.dev20240531__py3-none-any.whl → 1.10.0.dev20240619__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/app.py +2 -0
- flwr/cli/build.py +4 -15
- flwr/cli/config_utils.py +64 -7
- flwr/cli/install.py +211 -0
- flwr/cli/new/templates/app/pyproject.hf.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/run.py +39 -2
- flwr/cli/utils.py +14 -0
- flwr/client/__init__.py +1 -0
- flwr/client/app.py +153 -103
- flwr/client/client_app.py +1 -1
- flwr/client/grpc_adapter_client/__init__.py +15 -0
- flwr/client/grpc_adapter_client/connection.py +94 -0
- flwr/client/grpc_client/connection.py +5 -1
- flwr/client/grpc_rere_client/client_interceptor.py +1 -1
- flwr/client/grpc_rere_client/connection.py +9 -5
- flwr/client/grpc_rere_client/grpc_adapter.py +133 -0
- flwr/client/mod/__init__.py +4 -4
- flwr/client/rest_client/connection.py +10 -3
- flwr/client/supernode/app.py +155 -31
- flwr/common/__init__.py +12 -12
- flwr/common/config.py +71 -0
- flwr/common/constant.py +15 -0
- flwr/common/object_ref.py +52 -14
- flwr/common/record/__init__.py +1 -1
- flwr/common/telemetry.py +4 -0
- flwr/common/typing.py +9 -0
- flwr/proto/driver_pb2.py +20 -19
- flwr/proto/driver_pb2_grpc.py +35 -0
- flwr/proto/driver_pb2_grpc.pyi +14 -0
- flwr/proto/exec_pb2.py +34 -0
- flwr/proto/exec_pb2.pyi +55 -0
- flwr/proto/exec_pb2_grpc.py +101 -0
- flwr/proto/exec_pb2_grpc.pyi +41 -0
- flwr/proto/fab_pb2.py +30 -0
- flwr/proto/fab_pb2.pyi +56 -0
- flwr/proto/fab_pb2_grpc.py +4 -0
- flwr/proto/fab_pb2_grpc.pyi +4 -0
- flwr/proto/fleet_pb2.py +28 -33
- flwr/proto/fleet_pb2.pyi +0 -42
- flwr/proto/fleet_pb2_grpc.py +7 -6
- flwr/proto/fleet_pb2_grpc.pyi +5 -4
- flwr/proto/run_pb2.py +30 -0
- flwr/proto/run_pb2.pyi +52 -0
- flwr/proto/run_pb2_grpc.py +4 -0
- flwr/proto/run_pb2_grpc.pyi +4 -0
- flwr/server/__init__.py +2 -6
- flwr/server/app.py +94 -214
- flwr/server/run_serverapp.py +33 -7
- flwr/server/server_app.py +2 -2
- flwr/server/strategy/__init__.py +2 -2
- flwr/server/superlink/driver/driver_servicer.py +7 -0
- flwr/server/superlink/fleet/grpc_adapter/__init__.py +15 -0
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +131 -0
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +4 -0
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +1 -2
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +1 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +8 -6
- flwr/server/superlink/fleet/rest_rere/rest_api.py +1 -1
- flwr/server/superlink/fleet/vce/vce_api.py +3 -1
- flwr/server/superlink/state/in_memory_state.py +8 -5
- flwr/server/superlink/state/sqlite_state.py +6 -3
- flwr/server/superlink/state/state.py +5 -4
- flwr/simulation/__init__.py +4 -1
- flwr/simulation/run_simulation.py +22 -0
- flwr/superexec/__init__.py +21 -0
- flwr/superexec/app.py +178 -0
- flwr/superexec/exec_grpc.py +51 -0
- flwr/superexec/exec_servicer.py +65 -0
- flwr/superexec/executor.py +54 -0
- {flwr_nightly-1.9.0.dev20240531.dist-info → flwr_nightly-1.10.0.dev20240619.dist-info}/METADATA +1 -1
- {flwr_nightly-1.9.0.dev20240531.dist-info → flwr_nightly-1.10.0.dev20240619.dist-info}/RECORD +80 -56
- {flwr_nightly-1.9.0.dev20240531.dist-info → flwr_nightly-1.10.0.dev20240619.dist-info}/entry_points.txt +1 -2
- {flwr_nightly-1.9.0.dev20240531.dist-info → flwr_nightly-1.10.0.dev20240619.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.9.0.dev20240531.dist-info → flwr_nightly-1.10.0.dev20240619.dist-info}/WHEEL +0 -0
flwr/server/app.py
CHANGED
|
@@ -36,11 +36,12 @@ from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
|
|
|
36
36
|
from flwr.common.address import parse_address
|
|
37
37
|
from flwr.common.constant import (
|
|
38
38
|
MISSING_EXTRA_REST,
|
|
39
|
+
TRANSPORT_TYPE_GRPC_ADAPTER,
|
|
39
40
|
TRANSPORT_TYPE_GRPC_RERE,
|
|
40
41
|
TRANSPORT_TYPE_REST,
|
|
41
42
|
)
|
|
42
43
|
from flwr.common.exit_handlers import register_exit_handlers
|
|
43
|
-
from flwr.common.logger import log
|
|
44
|
+
from flwr.common.logger import log
|
|
44
45
|
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
|
45
46
|
private_key_to_bytes,
|
|
46
47
|
public_key_to_bytes,
|
|
@@ -48,6 +49,7 @@ from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
|
|
|
48
49
|
from flwr.proto.fleet_pb2_grpc import ( # pylint: disable=E0611
|
|
49
50
|
add_FleetServicer_to_server,
|
|
50
51
|
)
|
|
52
|
+
from flwr.proto.grpcadapter_pb2_grpc import add_GrpcAdapterServicer_to_server
|
|
51
53
|
|
|
52
54
|
from .client_manager import ClientManager
|
|
53
55
|
from .history import History
|
|
@@ -55,6 +57,7 @@ from .server import Server, init_defaults, run_fl
|
|
|
55
57
|
from .server_config import ServerConfig
|
|
56
58
|
from .strategy import Strategy
|
|
57
59
|
from .superlink.driver.driver_grpc import run_driver_api_grpc
|
|
60
|
+
from .superlink.fleet.grpc_adapter.grpc_adapter_servicer import GrpcAdapterServicer
|
|
58
61
|
from .superlink.fleet.grpc_bidi.grpc_server import (
|
|
59
62
|
generic_create_grpc_server,
|
|
60
63
|
start_grpc_server,
|
|
@@ -190,120 +193,6 @@ def start_server( # pylint: disable=too-many-arguments,too-many-locals
|
|
|
190
193
|
return hist
|
|
191
194
|
|
|
192
195
|
|
|
193
|
-
def run_driver_api() -> None:
|
|
194
|
-
"""Run Flower server (Driver API)."""
|
|
195
|
-
log(INFO, "Starting Flower server (Driver API)")
|
|
196
|
-
# Running `flower-driver-api` is deprecated
|
|
197
|
-
warn_deprecated_feature("flower-driver-api")
|
|
198
|
-
log(WARN, "Use `flower-superlink` instead")
|
|
199
|
-
event(EventType.RUN_DRIVER_API_ENTER)
|
|
200
|
-
args = _parse_args_run_driver_api().parse_args()
|
|
201
|
-
|
|
202
|
-
# Parse IP address
|
|
203
|
-
parsed_address = parse_address(args.driver_api_address)
|
|
204
|
-
if not parsed_address:
|
|
205
|
-
sys.exit(f"Driver IP address ({args.driver_api_address}) cannot be parsed.")
|
|
206
|
-
host, port, is_v6 = parsed_address
|
|
207
|
-
address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}"
|
|
208
|
-
|
|
209
|
-
# Obtain certificates
|
|
210
|
-
certificates = _try_obtain_certificates(args)
|
|
211
|
-
|
|
212
|
-
# Initialize StateFactory
|
|
213
|
-
state_factory = StateFactory(args.database)
|
|
214
|
-
|
|
215
|
-
# Start server
|
|
216
|
-
grpc_server: grpc.Server = run_driver_api_grpc(
|
|
217
|
-
address=address,
|
|
218
|
-
state_factory=state_factory,
|
|
219
|
-
certificates=certificates,
|
|
220
|
-
)
|
|
221
|
-
|
|
222
|
-
# Graceful shutdown
|
|
223
|
-
register_exit_handlers(
|
|
224
|
-
event_type=EventType.RUN_DRIVER_API_LEAVE,
|
|
225
|
-
grpc_servers=[grpc_server],
|
|
226
|
-
bckg_threads=[],
|
|
227
|
-
)
|
|
228
|
-
|
|
229
|
-
# Block
|
|
230
|
-
grpc_server.wait_for_termination()
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
def run_fleet_api() -> None:
|
|
234
|
-
"""Run Flower server (Fleet API)."""
|
|
235
|
-
log(INFO, "Starting Flower server (Fleet API)")
|
|
236
|
-
# Running `flower-fleet-api` is deprecated
|
|
237
|
-
warn_deprecated_feature("flower-fleet-api")
|
|
238
|
-
log(WARN, "Use `flower-superlink` instead")
|
|
239
|
-
event(EventType.RUN_FLEET_API_ENTER)
|
|
240
|
-
args = _parse_args_run_fleet_api().parse_args()
|
|
241
|
-
|
|
242
|
-
# Obtain certificates
|
|
243
|
-
certificates = _try_obtain_certificates(args)
|
|
244
|
-
|
|
245
|
-
# Initialize StateFactory
|
|
246
|
-
state_factory = StateFactory(args.database)
|
|
247
|
-
|
|
248
|
-
grpc_servers = []
|
|
249
|
-
bckg_threads = []
|
|
250
|
-
|
|
251
|
-
# Start Fleet API
|
|
252
|
-
if args.fleet_api_type == TRANSPORT_TYPE_REST:
|
|
253
|
-
if (
|
|
254
|
-
importlib.util.find_spec("requests")
|
|
255
|
-
and importlib.util.find_spec("starlette")
|
|
256
|
-
and importlib.util.find_spec("uvicorn")
|
|
257
|
-
) is None:
|
|
258
|
-
sys.exit(MISSING_EXTRA_REST)
|
|
259
|
-
address_arg = args.rest_fleet_api_address
|
|
260
|
-
parsed_address = parse_address(address_arg)
|
|
261
|
-
if not parsed_address:
|
|
262
|
-
sys.exit(f"Fleet IP address ({address_arg}) cannot be parsed.")
|
|
263
|
-
host, port, _ = parsed_address
|
|
264
|
-
fleet_thread = threading.Thread(
|
|
265
|
-
target=_run_fleet_api_rest,
|
|
266
|
-
args=(
|
|
267
|
-
host,
|
|
268
|
-
port,
|
|
269
|
-
args.ssl_keyfile,
|
|
270
|
-
args.ssl_certfile,
|
|
271
|
-
state_factory,
|
|
272
|
-
args.rest_fleet_api_workers,
|
|
273
|
-
),
|
|
274
|
-
)
|
|
275
|
-
fleet_thread.start()
|
|
276
|
-
bckg_threads.append(fleet_thread)
|
|
277
|
-
elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE:
|
|
278
|
-
address_arg = args.grpc_rere_fleet_api_address
|
|
279
|
-
parsed_address = parse_address(address_arg)
|
|
280
|
-
if not parsed_address:
|
|
281
|
-
sys.exit(f"Fleet IP address ({address_arg}) cannot be parsed.")
|
|
282
|
-
host, port, is_v6 = parsed_address
|
|
283
|
-
address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}"
|
|
284
|
-
fleet_server = _run_fleet_api_grpc_rere(
|
|
285
|
-
address=address,
|
|
286
|
-
state_factory=state_factory,
|
|
287
|
-
certificates=certificates,
|
|
288
|
-
)
|
|
289
|
-
grpc_servers.append(fleet_server)
|
|
290
|
-
else:
|
|
291
|
-
raise ValueError(f"Unknown fleet_api_type: {args.fleet_api_type}")
|
|
292
|
-
|
|
293
|
-
# Graceful shutdown
|
|
294
|
-
register_exit_handlers(
|
|
295
|
-
event_type=EventType.RUN_FLEET_API_LEAVE,
|
|
296
|
-
grpc_servers=grpc_servers,
|
|
297
|
-
bckg_threads=bckg_threads,
|
|
298
|
-
)
|
|
299
|
-
|
|
300
|
-
# Block
|
|
301
|
-
if len(grpc_servers) > 0:
|
|
302
|
-
grpc_servers[0].wait_for_termination()
|
|
303
|
-
elif len(bckg_threads) > 0:
|
|
304
|
-
bckg_threads[0].join()
|
|
305
|
-
|
|
306
|
-
|
|
307
196
|
# pylint: disable=too-many-branches, too-many-locals, too-many-statements
|
|
308
197
|
def run_superlink() -> None:
|
|
309
198
|
"""Run Flower SuperLink (Driver API and Fleet API)."""
|
|
@@ -314,11 +203,7 @@ def run_superlink() -> None:
|
|
|
314
203
|
args = _parse_args_run_superlink().parse_args()
|
|
315
204
|
|
|
316
205
|
# Parse IP address
|
|
317
|
-
|
|
318
|
-
if not parsed_address:
|
|
319
|
-
sys.exit(f"Driver IP address ({args.driver_api_address}) cannot be parsed.")
|
|
320
|
-
host, port, is_v6 = parsed_address
|
|
321
|
-
address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}"
|
|
206
|
+
driver_address, _, _ = _format_address(args.driver_api_address)
|
|
322
207
|
|
|
323
208
|
# Obtain certificates
|
|
324
209
|
certificates = _try_obtain_certificates(args)
|
|
@@ -328,13 +213,35 @@ def run_superlink() -> None:
|
|
|
328
213
|
|
|
329
214
|
# Start Driver API
|
|
330
215
|
driver_server: grpc.Server = run_driver_api_grpc(
|
|
331
|
-
address=address,
|
|
216
|
+
address=driver_address,
|
|
332
217
|
state_factory=state_factory,
|
|
333
218
|
certificates=certificates,
|
|
334
219
|
)
|
|
335
220
|
|
|
336
221
|
grpc_servers = [driver_server]
|
|
337
222
|
bckg_threads = []
|
|
223
|
+
if not args.fleet_api_address:
|
|
224
|
+
if args.fleet_api_type in [
|
|
225
|
+
TRANSPORT_TYPE_GRPC_RERE,
|
|
226
|
+
TRANSPORT_TYPE_GRPC_ADAPTER,
|
|
227
|
+
]:
|
|
228
|
+
args.fleet_api_address = ADDRESS_FLEET_API_GRPC_RERE
|
|
229
|
+
elif args.fleet_api_type == TRANSPORT_TYPE_REST:
|
|
230
|
+
args.fleet_api_address = ADDRESS_FLEET_API_REST
|
|
231
|
+
|
|
232
|
+
fleet_address, host, port = _format_address(args.fleet_api_address)
|
|
233
|
+
|
|
234
|
+
num_workers = args.fleet_api_num_workers
|
|
235
|
+
if num_workers != 1:
|
|
236
|
+
log(
|
|
237
|
+
WARN,
|
|
238
|
+
"The Fleet API currently supports only 1 worker. "
|
|
239
|
+
"You have specified %d workers. "
|
|
240
|
+
"Support for multiple workers will be added in future releases. "
|
|
241
|
+
"Proceeding with a single worker.",
|
|
242
|
+
args.fleet_api_num_workers,
|
|
243
|
+
)
|
|
244
|
+
num_workers = 1
|
|
338
245
|
|
|
339
246
|
# Start Fleet API
|
|
340
247
|
if args.fleet_api_type == TRANSPORT_TYPE_REST:
|
|
@@ -344,14 +251,11 @@ def run_superlink() -> None:
|
|
|
344
251
|
and importlib.util.find_spec("uvicorn")
|
|
345
252
|
) is None:
|
|
346
253
|
sys.exit(MISSING_EXTRA_REST)
|
|
347
|
-
|
|
348
|
-
parsed_address = parse_address(address_arg)
|
|
254
|
+
|
|
349
255
|
_, ssl_certfile, ssl_keyfile = (
|
|
350
256
|
certificates if certificates is not None else (None, None, None)
|
|
351
257
|
)
|
|
352
|
-
|
|
353
|
-
sys.exit(f"Fleet IP address ({address_arg}) cannot be parsed.")
|
|
354
|
-
host, port, _ = parsed_address
|
|
258
|
+
|
|
355
259
|
fleet_thread = threading.Thread(
|
|
356
260
|
target=_run_fleet_api_rest,
|
|
357
261
|
args=(
|
|
@@ -360,19 +264,12 @@ def run_superlink() -> None:
|
|
|
360
264
|
ssl_keyfile,
|
|
361
265
|
ssl_certfile,
|
|
362
266
|
state_factory,
|
|
363
|
-
|
|
267
|
+
num_workers,
|
|
364
268
|
),
|
|
365
269
|
)
|
|
366
270
|
fleet_thread.start()
|
|
367
271
|
bckg_threads.append(fleet_thread)
|
|
368
272
|
elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE:
|
|
369
|
-
address_arg = args.grpc_rere_fleet_api_address
|
|
370
|
-
parsed_address = parse_address(address_arg)
|
|
371
|
-
if not parsed_address:
|
|
372
|
-
sys.exit(f"Fleet IP address ({address_arg}) cannot be parsed.")
|
|
373
|
-
host, port, is_v6 = parsed_address
|
|
374
|
-
address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}"
|
|
375
|
-
|
|
376
273
|
maybe_keys = _try_setup_client_authentication(args, certificates)
|
|
377
274
|
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None
|
|
378
275
|
if maybe_keys is not None:
|
|
@@ -395,12 +292,19 @@ def run_superlink() -> None:
|
|
|
395
292
|
interceptors = [AuthenticateServerInterceptor(state)]
|
|
396
293
|
|
|
397
294
|
fleet_server = _run_fleet_api_grpc_rere(
|
|
398
|
-
address=address,
|
|
295
|
+
address=fleet_address,
|
|
399
296
|
state_factory=state_factory,
|
|
400
297
|
certificates=certificates,
|
|
401
298
|
interceptors=interceptors,
|
|
402
299
|
)
|
|
403
300
|
grpc_servers.append(fleet_server)
|
|
301
|
+
elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_ADAPTER:
|
|
302
|
+
fleet_server = _run_fleet_api_grpc_adapter(
|
|
303
|
+
address=fleet_address,
|
|
304
|
+
state_factory=state_factory,
|
|
305
|
+
certificates=certificates,
|
|
306
|
+
)
|
|
307
|
+
grpc_servers.append(fleet_server)
|
|
404
308
|
else:
|
|
405
309
|
raise ValueError(f"Unknown fleet_api_type: {args.fleet_api_type}")
|
|
406
310
|
|
|
@@ -420,6 +324,16 @@ def run_superlink() -> None:
|
|
|
420
324
|
driver_server.wait_for_termination(timeout=1)
|
|
421
325
|
|
|
422
326
|
|
|
327
|
+
def _format_address(address: str) -> Tuple[str, str, int]:
|
|
328
|
+
parsed_address = parse_address(address)
|
|
329
|
+
if not parsed_address:
|
|
330
|
+
sys.exit(
|
|
331
|
+
f"Address ({address}) cannot be parsed (expected: URL or IPv4 or IPv6)."
|
|
332
|
+
)
|
|
333
|
+
host, port, is_v6 = parsed_address
|
|
334
|
+
return (f"[{host}]:{port}" if is_v6 else f"{host}:{port}", host, port)
|
|
335
|
+
|
|
336
|
+
|
|
423
337
|
def _try_setup_client_authentication(
|
|
424
338
|
args: argparse.Namespace,
|
|
425
339
|
certificates: Optional[Tuple[bytes, bytes, bytes]],
|
|
@@ -517,7 +431,7 @@ def _try_obtain_certificates(
|
|
|
517
431
|
log(WARN, "Option `--insecure` was set. Starting insecure HTTP server.")
|
|
518
432
|
return None
|
|
519
433
|
# Check if certificates are provided
|
|
520
|
-
if args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE:
|
|
434
|
+
if args.fleet_api_type in [TRANSPORT_TYPE_GRPC_RERE, TRANSPORT_TYPE_GRPC_ADAPTER]:
|
|
521
435
|
if args.ssl_certfile and args.ssl_keyfile and args.ssl_ca_certfile:
|
|
522
436
|
if not isfile(args.ssl_ca_certfile):
|
|
523
437
|
sys.exit("Path argument `--ssl-ca-certfile` does not point to a file.")
|
|
@@ -589,6 +503,30 @@ def _run_fleet_api_grpc_rere(
|
|
|
589
503
|
return fleet_grpc_server
|
|
590
504
|
|
|
591
505
|
|
|
506
|
+
def _run_fleet_api_grpc_adapter(
|
|
507
|
+
address: str,
|
|
508
|
+
state_factory: StateFactory,
|
|
509
|
+
certificates: Optional[Tuple[bytes, bytes, bytes]],
|
|
510
|
+
) -> grpc.Server:
|
|
511
|
+
"""Run Fleet API (GrpcAdapter)."""
|
|
512
|
+
# Create Fleet API gRPC server
|
|
513
|
+
fleet_servicer = GrpcAdapterServicer(
|
|
514
|
+
state_factory=state_factory,
|
|
515
|
+
)
|
|
516
|
+
fleet_add_servicer_to_server_fn = add_GrpcAdapterServicer_to_server
|
|
517
|
+
fleet_grpc_server = generic_create_grpc_server(
|
|
518
|
+
servicer_and_add_fn=(fleet_servicer, fleet_add_servicer_to_server_fn),
|
|
519
|
+
server_address=address,
|
|
520
|
+
max_message_length=GRPC_MAX_MESSAGE_LENGTH,
|
|
521
|
+
certificates=certificates,
|
|
522
|
+
)
|
|
523
|
+
|
|
524
|
+
log(INFO, "Flower ECE: Starting Fleet API (GrpcAdapter) on %s", address)
|
|
525
|
+
fleet_grpc_server.start()
|
|
526
|
+
|
|
527
|
+
return fleet_grpc_server
|
|
528
|
+
|
|
529
|
+
|
|
592
530
|
# pylint: disable=import-outside-toplevel,too-many-arguments
|
|
593
531
|
def _run_fleet_api_rest(
|
|
594
532
|
host: str,
|
|
@@ -596,7 +534,7 @@ def _run_fleet_api_rest(
|
|
|
596
534
|
ssl_keyfile: Optional[str],
|
|
597
535
|
ssl_certfile: Optional[str],
|
|
598
536
|
state_factory: StateFactory,
|
|
599
|
-
|
|
537
|
+
num_workers: int,
|
|
600
538
|
) -> None:
|
|
601
539
|
"""Run Driver API (REST-based)."""
|
|
602
540
|
try:
|
|
@@ -605,12 +543,7 @@ def _run_fleet_api_rest(
|
|
|
605
543
|
from flwr.server.superlink.fleet.rest_rere.rest_api import app as fast_api_app
|
|
606
544
|
except ModuleNotFoundError:
|
|
607
545
|
sys.exit(MISSING_EXTRA_REST)
|
|
608
|
-
|
|
609
|
-
raise ValueError(
|
|
610
|
-
f"The supported number of workers for the Fleet API (REST server) is "
|
|
611
|
-
f"1. Instead given {workers}. The functionality of >1 workers will be "
|
|
612
|
-
f"added in the future releases."
|
|
613
|
-
)
|
|
546
|
+
|
|
614
547
|
log(INFO, "Starting Flower REST server")
|
|
615
548
|
|
|
616
549
|
# See: https://www.starlette.io/applications/#accessing-the-app-instance
|
|
@@ -624,44 +557,10 @@ def _run_fleet_api_rest(
|
|
|
624
557
|
access_log=True,
|
|
625
558
|
ssl_keyfile=ssl_keyfile,
|
|
626
559
|
ssl_certfile=ssl_certfile,
|
|
627
|
-
workers=workers,
|
|
560
|
+
workers=num_workers,
|
|
628
561
|
)
|
|
629
562
|
|
|
630
563
|
|
|
631
|
-
def _parse_args_run_driver_api() -> argparse.ArgumentParser:
|
|
632
|
-
"""Parse command line arguments for Driver API."""
|
|
633
|
-
parser = argparse.ArgumentParser(
|
|
634
|
-
description="Start a Flower Driver API server. "
|
|
635
|
-
"This server will be responsible for "
|
|
636
|
-
"receiving TaskIns from the Driver script and "
|
|
637
|
-
"sending them to the Fleet API. Once the client nodes "
|
|
638
|
-
"are done, they will send the TaskRes back to this Driver API server (through"
|
|
639
|
-
" the Fleet API) which will then send them back to the Driver script.",
|
|
640
|
-
)
|
|
641
|
-
|
|
642
|
-
_add_args_common(parser=parser)
|
|
643
|
-
_add_args_driver_api(parser=parser)
|
|
644
|
-
|
|
645
|
-
return parser
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
def _parse_args_run_fleet_api() -> argparse.ArgumentParser:
|
|
649
|
-
"""Parse command line arguments for Fleet API."""
|
|
650
|
-
parser = argparse.ArgumentParser(
|
|
651
|
-
description="Start a Flower Fleet API server."
|
|
652
|
-
"This server will be responsible for "
|
|
653
|
-
"sending TaskIns (received from the Driver API) to the client nodes "
|
|
654
|
-
"and of receiving TaskRes sent back from those same client nodes once "
|
|
655
|
-
"they are done. Then, this Fleet API server can send those "
|
|
656
|
-
"TaskRes back to the Driver API.",
|
|
657
|
-
)
|
|
658
|
-
|
|
659
|
-
_add_args_common(parser=parser)
|
|
660
|
-
_add_args_fleet_api(parser=parser)
|
|
661
|
-
|
|
662
|
-
return parser
|
|
663
|
-
|
|
664
|
-
|
|
665
564
|
def _parse_args_run_superlink() -> argparse.ArgumentParser:
|
|
666
565
|
"""Parse command line arguments for both Driver API and Fleet API."""
|
|
667
566
|
parser = argparse.ArgumentParser(
|
|
@@ -732,50 +631,31 @@ def _add_args_common(parser: argparse.ArgumentParser) -> None:
|
|
|
732
631
|
def _add_args_driver_api(parser: argparse.ArgumentParser) -> None:
|
|
733
632
|
parser.add_argument(
|
|
734
633
|
"--driver-api-address",
|
|
735
|
-
help="Driver API (gRPC) server address (IPv4, IPv6, or a domain name)",
|
|
634
|
+
help="Driver API (gRPC) server address (IPv4, IPv6, or a domain name).",
|
|
736
635
|
default=ADDRESS_DRIVER_API,
|
|
737
636
|
)
|
|
738
637
|
|
|
739
638
|
|
|
740
639
|
def _add_args_fleet_api(parser: argparse.ArgumentParser) -> None:
|
|
741
640
|
# Fleet API transport layer type
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
"--grpc-rere",
|
|
745
|
-
action="store_const",
|
|
746
|
-
dest="fleet_api_type",
|
|
747
|
-
const=TRANSPORT_TYPE_GRPC_RERE,
|
|
641
|
+
parser.add_argument(
|
|
642
|
+
"--fleet-api-type",
|
|
748
643
|
default=TRANSPORT_TYPE_GRPC_RERE,
|
|
749
|
-
|
|
750
|
-
|
|
751
|
-
|
|
752
|
-
|
|
753
|
-
|
|
754
|
-
|
|
755
|
-
|
|
756
|
-
help="Start a Fleet API server (REST, experimental)",
|
|
757
|
-
)
|
|
758
|
-
|
|
759
|
-
# Fleet API gRPC-rere options
|
|
760
|
-
grpc_rere_group = parser.add_argument_group(
|
|
761
|
-
"Fleet API (gRPC-rere) server options", ""
|
|
762
|
-
)
|
|
763
|
-
grpc_rere_group.add_argument(
|
|
764
|
-
"--grpc-rere-fleet-api-address",
|
|
765
|
-
help="Fleet API (gRPC-rere) server address (IPv4, IPv6, or a domain name)",
|
|
766
|
-
default=ADDRESS_FLEET_API_GRPC_RERE,
|
|
644
|
+
type=str,
|
|
645
|
+
choices=[
|
|
646
|
+
TRANSPORT_TYPE_GRPC_RERE,
|
|
647
|
+
TRANSPORT_TYPE_GRPC_ADAPTER,
|
|
648
|
+
TRANSPORT_TYPE_REST,
|
|
649
|
+
],
|
|
650
|
+
help="Start a gRPC-rere or REST (experimental) Fleet API server.",
|
|
767
651
|
)
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
rest_group.add_argument(
|
|
772
|
-
"--rest-fleet-api-address",
|
|
773
|
-
help="Fleet API (REST) server address (IPv4, IPv6, or a domain name)",
|
|
774
|
-
default=ADDRESS_FLEET_API_REST,
|
|
652
|
+
parser.add_argument(
|
|
653
|
+
"--fleet-api-address",
|
|
654
|
+
help="Fleet API server address (IPv4, IPv6, or a domain name).",
|
|
775
655
|
)
|
|
776
|
-
|
|
777
|
-
"--rest-fleet-api-workers",
|
|
778
|
-
help="Set the number of concurrent workers for the Fleet API REST server.",
|
|
779
|
-
type=int,
|
|
656
|
+
parser.add_argument(
|
|
657
|
+
"--fleet-api-num-workers",
|
|
780
658
|
default=1,
|
|
659
|
+
type=int,
|
|
660
|
+
help="Set the number of concurrent workers for the Fleet API server.",
|
|
781
661
|
)
|
flwr/server/run_serverapp.py
CHANGED
|
@@ -22,12 +22,14 @@ from pathlib import Path
|
|
|
22
22
|
from typing import Optional
|
|
23
23
|
|
|
24
24
|
from flwr.common import Context, EventType, RecordSet, event
|
|
25
|
-
from flwr.common.logger import log, update_console_handler
|
|
25
|
+
from flwr.common.logger import log, update_console_handler, warn_deprecated_feature
|
|
26
26
|
from flwr.common.object_ref import load_app
|
|
27
27
|
|
|
28
28
|
from .driver import Driver, GrpcDriver
|
|
29
29
|
from .server_app import LoadServerAppError, ServerApp
|
|
30
30
|
|
|
31
|
+
ADDRESS_DRIVER_API = "0.0.0.0:9091"
|
|
32
|
+
|
|
31
33
|
|
|
32
34
|
def run(
|
|
33
35
|
driver: Driver,
|
|
@@ -43,12 +45,14 @@ def run(
|
|
|
43
45
|
)
|
|
44
46
|
|
|
45
47
|
if server_app_dir is not None:
|
|
46
|
-
sys.path.insert(0, server_app_dir)
|
|
48
|
+
sys.path.insert(0, str(Path(server_app_dir).absolute()))
|
|
47
49
|
|
|
48
50
|
# Load ServerApp if needed
|
|
49
51
|
def _load() -> ServerApp:
|
|
50
52
|
if server_app_attr:
|
|
51
|
-
server_app: ServerApp = load_app(
|
|
53
|
+
server_app: ServerApp = load_app(
|
|
54
|
+
server_app_attr, LoadServerAppError, server_app_dir
|
|
55
|
+
)
|
|
52
56
|
|
|
53
57
|
if not isinstance(server_app, ServerApp):
|
|
54
58
|
raise LoadServerAppError(
|
|
@@ -76,6 +80,23 @@ def run_server_app() -> None:
|
|
|
76
80
|
|
|
77
81
|
args = _parse_args_run_server_app().parse_args()
|
|
78
82
|
|
|
83
|
+
if args.server != ADDRESS_DRIVER_API:
|
|
84
|
+
warn = "Passing flag --server is deprecated. Use --superlink instead."
|
|
85
|
+
warn_deprecated_feature(warn)
|
|
86
|
+
|
|
87
|
+
if args.superlink != ADDRESS_DRIVER_API:
|
|
88
|
+
# if `--superlink` also passed, then
|
|
89
|
+
# warn user that this argument overrides what was passed with `--server`
|
|
90
|
+
log(
|
|
91
|
+
WARN,
|
|
92
|
+
"Both `--server` and `--superlink` were passed. "
|
|
93
|
+
"`--server` will be ignored. Connecting to the Superlink Driver API "
|
|
94
|
+
"at %s.",
|
|
95
|
+
args.superlink,
|
|
96
|
+
)
|
|
97
|
+
else:
|
|
98
|
+
args.superlink = args.server
|
|
99
|
+
|
|
79
100
|
update_console_handler(
|
|
80
101
|
level=DEBUG if args.verbose else INFO,
|
|
81
102
|
timestamps=args.verbose,
|
|
@@ -95,7 +116,7 @@ def run_server_app() -> None:
|
|
|
95
116
|
WARN,
|
|
96
117
|
"Option `--insecure` was set. "
|
|
97
118
|
"Starting insecure HTTP client connected to %s.",
|
|
98
|
-
args.server,
|
|
119
|
+
args.superlink,
|
|
99
120
|
)
|
|
100
121
|
root_certificates = None
|
|
101
122
|
else:
|
|
@@ -109,7 +130,7 @@ def run_server_app() -> None:
|
|
|
109
130
|
DEBUG,
|
|
110
131
|
"Starting secure HTTPS client connected to %s "
|
|
111
132
|
"with the following certificates: %s.",
|
|
112
|
-
args.server,
|
|
133
|
+
args.superlink,
|
|
113
134
|
cert_path,
|
|
114
135
|
)
|
|
115
136
|
|
|
@@ -130,7 +151,7 @@ def run_server_app() -> None:
|
|
|
130
151
|
|
|
131
152
|
# Initialize GrpcDriver
|
|
132
153
|
driver = GrpcDriver(
|
|
133
|
-
driver_service_address=args.server,
|
|
154
|
+
driver_service_address=args.superlink,
|
|
134
155
|
root_certificates=root_certificates,
|
|
135
156
|
fab_id=args.fab_id,
|
|
136
157
|
fab_version=args.fab_version,
|
|
@@ -175,9 +196,14 @@ def _parse_args_run_server_app() -> argparse.ArgumentParser:
|
|
|
175
196
|
)
|
|
176
197
|
parser.add_argument(
|
|
177
198
|
"--server",
|
|
178
|
-
default=
|
|
199
|
+
default=ADDRESS_DRIVER_API,
|
|
179
200
|
help="Server address",
|
|
180
201
|
)
|
|
202
|
+
parser.add_argument(
|
|
203
|
+
"--superlink",
|
|
204
|
+
default=ADDRESS_DRIVER_API,
|
|
205
|
+
help="SuperLink Driver API (gRPC-rere) address (IPv4, IPv6, or a domain name)",
|
|
206
|
+
)
|
|
181
207
|
parser.add_argument(
|
|
182
208
|
"--dir",
|
|
183
209
|
default="",
|
flwr/server/server_app.py
CHANGED
|
@@ -39,7 +39,7 @@ class ServerApp:
|
|
|
39
39
|
>>> server_config = ServerConfig(num_rounds=3)
|
|
40
40
|
>>> strategy = FedAvg()
|
|
41
41
|
>>>
|
|
42
|
-
>>> app = ServerApp(
|
|
42
|
+
>>> app = ServerApp(
|
|
43
43
|
>>> server_config=server_config,
|
|
44
44
|
>>> strategy=strategy,
|
|
45
45
|
>>> )
|
|
@@ -106,7 +106,7 @@ class ServerApp:
|
|
|
106
106
|
>>> server_config = ServerConfig(num_rounds=3)
|
|
107
107
|
>>> strategy = FedAvg()
|
|
108
108
|
>>>
|
|
109
|
-
>>> app = ServerApp(
|
|
109
|
+
>>> app = ServerApp(
|
|
110
110
|
>>> server_config=server_config,
|
|
111
111
|
>>> strategy=strategy,
|
|
112
112
|
>>> )
|
flwr/server/strategy/__init__.py
CHANGED
|
@@ -53,9 +53,10 @@ __all__ = [
|
|
|
53
53
|
"DPFedAvgAdaptive",
|
|
54
54
|
"DPFedAvgFixed",
|
|
55
55
|
"DifferentialPrivacyClientSideAdaptiveClipping",
|
|
56
|
-
"DifferentialPrivacyServerSideAdaptiveClipping",
|
|
57
56
|
"DifferentialPrivacyClientSideFixedClipping",
|
|
57
|
+
"DifferentialPrivacyServerSideAdaptiveClipping",
|
|
58
58
|
"DifferentialPrivacyServerSideFixedClipping",
|
|
59
|
+
"FaultTolerantFedAvg",
|
|
59
60
|
"FedAdagrad",
|
|
60
61
|
"FedAdam",
|
|
61
62
|
"FedAvg",
|
|
@@ -69,7 +70,6 @@ __all__ = [
|
|
|
69
70
|
"FedXgbCyclic",
|
|
70
71
|
"FedXgbNnAvg",
|
|
71
72
|
"FedYogi",
|
|
72
|
-
"FaultTolerantFedAvg",
|
|
73
73
|
"Krum",
|
|
74
74
|
"QFedAvg",
|
|
75
75
|
"Strategy",
|
|
@@ -35,6 +35,7 @@ from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
|
|
|
35
35
|
PushTaskInsResponse,
|
|
36
36
|
)
|
|
37
37
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
38
|
+
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
38
39
|
from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611
|
|
39
40
|
from flwr.server.superlink.state import State, StateFactory
|
|
40
41
|
from flwr.server.utils.validator import validate_task_ins_or_res
|
|
@@ -129,6 +130,12 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
|
|
|
129
130
|
context.set_code(grpc.StatusCode.OK)
|
|
130
131
|
return PullTaskResResponse(task_res_list=task_res_list)
|
|
131
132
|
|
|
133
|
+
def GetRun(
|
|
134
|
+
self, request: GetRunRequest, context: grpc.ServicerContext
|
|
135
|
+
) -> GetRunResponse:
|
|
136
|
+
"""Get run information."""
|
|
137
|
+
raise NotImplementedError
|
|
138
|
+
|
|
132
139
|
|
|
133
140
|
def _raise_if(validation_error: bool, detail: str) -> None:
|
|
134
141
|
if validation_error:
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Server-side part of the GrpcAdapter transport layer."""
|