flwr-nightly 1.9.0.dev20240520__py3-none-any.whl → 1.10.0.dev20240612__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

Files changed (53)
  1. flwr/cli/app.py +2 -0
  2. flwr/cli/build.py +4 -19
  3. flwr/cli/config_utils.py +12 -27
  4. flwr/cli/install.py +196 -0
  5. flwr/cli/new/templates/app/pyproject.hf.toml.tpl +7 -1
  6. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  7. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +7 -1
  8. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +7 -1
  9. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +7 -1
  10. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +7 -1
  11. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +7 -1
  12. flwr/cli/run/run.py +20 -4
  13. flwr/cli/utils.py +14 -0
  14. flwr/client/__init__.py +1 -0
  15. flwr/client/app.py +135 -97
  16. flwr/client/client_app.py +1 -1
  17. flwr/client/grpc_rere_client/client_interceptor.py +1 -1
  18. flwr/client/grpc_rere_client/connection.py +6 -6
  19. flwr/client/mod/__init__.py +1 -1
  20. flwr/client/rest_client/connection.py +1 -2
  21. flwr/client/supernode/app.py +70 -28
  22. flwr/common/object_ref.py +13 -9
  23. flwr/common/recordset_compat.py +8 -1
  24. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +0 -15
  25. flwr/proto/driver_pb2.py +20 -19
  26. flwr/proto/driver_pb2_grpc.py +35 -0
  27. flwr/proto/driver_pb2_grpc.pyi +14 -0
  28. flwr/proto/fleet_pb2.py +28 -33
  29. flwr/proto/fleet_pb2.pyi +0 -42
  30. flwr/proto/fleet_pb2_grpc.py +7 -6
  31. flwr/proto/fleet_pb2_grpc.pyi +5 -4
  32. flwr/proto/grpcadapter_pb2.py +32 -0
  33. flwr/proto/grpcadapter_pb2.pyi +43 -0
  34. flwr/proto/grpcadapter_pb2_grpc.py +66 -0
  35. flwr/proto/grpcadapter_pb2_grpc.pyi +24 -0
  36. flwr/proto/run_pb2.py +30 -0
  37. flwr/proto/run_pb2.pyi +52 -0
  38. flwr/proto/run_pb2_grpc.py +4 -0
  39. flwr/proto/run_pb2_grpc.pyi +4 -0
  40. flwr/server/__init__.py +0 -4
  41. flwr/server/app.py +190 -395
  42. flwr/server/run_serverapp.py +29 -5
  43. flwr/server/server_app.py +2 -2
  44. flwr/server/superlink/driver/driver_servicer.py +7 -0
  45. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +1 -2
  46. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +1 -2
  47. flwr/server/superlink/fleet/message_handler/message_handler.py +5 -3
  48. flwr/server/superlink/fleet/rest_rere/rest_api.py +1 -1
  49. {flwr_nightly-1.9.0.dev20240520.dist-info → flwr_nightly-1.10.0.dev20240612.dist-info}/METADATA +4 -3
  50. {flwr_nightly-1.9.0.dev20240520.dist-info → flwr_nightly-1.10.0.dev20240612.dist-info}/RECORD +53 -44
  51. {flwr_nightly-1.9.0.dev20240520.dist-info → flwr_nightly-1.10.0.dev20240612.dist-info}/entry_points.txt +0 -2
  52. {flwr_nightly-1.9.0.dev20240520.dist-info → flwr_nightly-1.10.0.dev20240612.dist-info}/LICENSE +0 -0
  53. {flwr_nightly-1.9.0.dev20240520.dist-info → flwr_nightly-1.10.0.dev20240612.dist-info}/WHEEL +0 -0
flwr/server/app.py CHANGED
@@ -15,17 +15,17 @@
15
15
  """Flower server app."""
16
16
 
17
17
  import argparse
18
- import asyncio
19
18
  import csv
20
19
  import importlib.util
21
20
  import sys
22
21
  import threading
23
- from logging import ERROR, INFO, WARN
22
+ from logging import INFO, WARN
24
23
  from os.path import isfile
25
24
  from pathlib import Path
26
- from typing import List, Optional, Sequence, Set, Tuple
25
+ from typing import Optional, Sequence, Set, Tuple
27
26
 
28
27
  import grpc
28
+ from cryptography.exceptions import UnsupportedAlgorithm
29
29
  from cryptography.hazmat.primitives.asymmetric import ec
30
30
  from cryptography.hazmat.primitives.serialization import (
31
31
  load_ssh_private_key,
@@ -38,14 +38,12 @@ from flwr.common.constant import (
38
38
  MISSING_EXTRA_REST,
39
39
  TRANSPORT_TYPE_GRPC_RERE,
40
40
  TRANSPORT_TYPE_REST,
41
- TRANSPORT_TYPE_VCE,
42
41
  )
43
42
  from flwr.common.exit_handlers import register_exit_handlers
44
- from flwr.common.logger import log, warn_deprecated_feature
43
+ from flwr.common.logger import log
45
44
  from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
46
45
  private_key_to_bytes,
47
46
  public_key_to_bytes,
48
- ssh_types_to_elliptic_curve,
49
47
  )
50
48
  from flwr.proto.fleet_pb2_grpc import ( # pylint: disable=E0611
51
49
  add_FleetServicer_to_server,
@@ -63,7 +61,6 @@ from .superlink.fleet.grpc_bidi.grpc_server import (
63
61
  )
64
62
  from .superlink.fleet.grpc_rere.fleet_servicer import FleetServicer
65
63
  from .superlink.fleet.grpc_rere.server_interceptor import AuthenticateServerInterceptor
66
- from .superlink.fleet.vce import start_vce
67
64
  from .superlink.state import StateFactory
68
65
 
69
66
  ADDRESS_DRIVER_API = "0.0.0.0:9091"
@@ -193,120 +190,6 @@ def start_server( # pylint: disable=too-many-arguments,too-many-locals
193
190
  return hist
194
191
 
195
192
 
196
- def run_driver_api() -> None:
197
- """Run Flower server (Driver API)."""
198
- log(INFO, "Starting Flower server (Driver API)")
199
- # Running `flower-driver-api` is deprecated
200
- warn_deprecated_feature("flower-driver-api")
201
- log(WARN, "Use `flower-superlink` instead")
202
- event(EventType.RUN_DRIVER_API_ENTER)
203
- args = _parse_args_run_driver_api().parse_args()
204
-
205
- # Parse IP address
206
- parsed_address = parse_address(args.driver_api_address)
207
- if not parsed_address:
208
- sys.exit(f"Driver IP address ({args.driver_api_address}) cannot be parsed.")
209
- host, port, is_v6 = parsed_address
210
- address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}"
211
-
212
- # Obtain certificates
213
- certificates = _try_obtain_certificates(args)
214
-
215
- # Initialize StateFactory
216
- state_factory = StateFactory(args.database)
217
-
218
- # Start server
219
- grpc_server: grpc.Server = run_driver_api_grpc(
220
- address=address,
221
- state_factory=state_factory,
222
- certificates=certificates,
223
- )
224
-
225
- # Graceful shutdown
226
- register_exit_handlers(
227
- event_type=EventType.RUN_DRIVER_API_LEAVE,
228
- grpc_servers=[grpc_server],
229
- bckg_threads=[],
230
- )
231
-
232
- # Block
233
- grpc_server.wait_for_termination()
234
-
235
-
236
- def run_fleet_api() -> None:
237
- """Run Flower server (Fleet API)."""
238
- log(INFO, "Starting Flower server (Fleet API)")
239
- # Running `flower-fleet-api` is deprecated
240
- warn_deprecated_feature("flower-fleet-api")
241
- log(WARN, "Use `flower-superlink` instead")
242
- event(EventType.RUN_FLEET_API_ENTER)
243
- args = _parse_args_run_fleet_api().parse_args()
244
-
245
- # Obtain certificates
246
- certificates = _try_obtain_certificates(args)
247
-
248
- # Initialize StateFactory
249
- state_factory = StateFactory(args.database)
250
-
251
- grpc_servers = []
252
- bckg_threads = []
253
-
254
- # Start Fleet API
255
- if args.fleet_api_type == TRANSPORT_TYPE_REST:
256
- if (
257
- importlib.util.find_spec("requests")
258
- and importlib.util.find_spec("starlette")
259
- and importlib.util.find_spec("uvicorn")
260
- ) is None:
261
- sys.exit(MISSING_EXTRA_REST)
262
- address_arg = args.rest_fleet_api_address
263
- parsed_address = parse_address(address_arg)
264
- if not parsed_address:
265
- sys.exit(f"Fleet IP address ({address_arg}) cannot be parsed.")
266
- host, port, _ = parsed_address
267
- fleet_thread = threading.Thread(
268
- target=_run_fleet_api_rest,
269
- args=(
270
- host,
271
- port,
272
- args.ssl_keyfile,
273
- args.ssl_certfile,
274
- state_factory,
275
- args.rest_fleet_api_workers,
276
- ),
277
- )
278
- fleet_thread.start()
279
- bckg_threads.append(fleet_thread)
280
- elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE:
281
- address_arg = args.grpc_rere_fleet_api_address
282
- parsed_address = parse_address(address_arg)
283
- if not parsed_address:
284
- sys.exit(f"Fleet IP address ({address_arg}) cannot be parsed.")
285
- host, port, is_v6 = parsed_address
286
- address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}"
287
- fleet_server = _run_fleet_api_grpc_rere(
288
- address=address,
289
- state_factory=state_factory,
290
- certificates=certificates,
291
- )
292
- grpc_servers.append(fleet_server)
293
- else:
294
- raise ValueError(f"Unknown fleet_api_type: {args.fleet_api_type}")
295
-
296
- # Graceful shutdown
297
- register_exit_handlers(
298
- event_type=EventType.RUN_FLEET_API_LEAVE,
299
- grpc_servers=grpc_servers,
300
- bckg_threads=bckg_threads,
301
- )
302
-
303
- # Block
304
- if len(grpc_servers) > 0:
305
- grpc_servers[0].wait_for_termination()
306
- elif len(bckg_threads) > 0:
307
- bckg_threads[0].join()
308
-
309
-
310
193
  # pylint: disable=too-many-branches, too-many-locals, too-many-statements
311
194
  def run_superlink() -> None:
312
195
  """Run Flower SuperLink (Driver API and Fleet API)."""
@@ -317,11 +200,15 @@ def run_superlink() -> None:
317
200
  args = _parse_args_run_superlink().parse_args()
318
201
 
319
202
  # Parse IP address
320
- parsed_address = parse_address(args.driver_api_address)
321
- if not parsed_address:
203
+ parsed_driver_address = parse_address(args.driver_api_address)
204
+ if not parsed_driver_address:
322
205
  sys.exit(f"Driver IP address ({args.driver_api_address}) cannot be parsed.")
323
- host, port, is_v6 = parsed_address
324
- address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}"
206
+ driver_host, driver_port, driver_is_v6 = parsed_driver_address
207
+ driver_address = (
208
+ f"[{driver_host}]:{driver_port}"
209
+ if driver_is_v6
210
+ else f"{driver_host}:{driver_port}"
211
+ )
325
212
 
326
213
  # Obtain certificates
327
214
  certificates = _try_obtain_certificates(args)
@@ -331,13 +218,38 @@ def run_superlink() -> None:
331
218
 
332
219
  # Start Driver API
333
220
  driver_server: grpc.Server = run_driver_api_grpc(
334
- address=address,
221
+ address=driver_address,
335
222
  state_factory=state_factory,
336
223
  certificates=certificates,
337
224
  )
338
225
 
339
226
  grpc_servers = [driver_server]
340
227
  bckg_threads = []
228
+ if not args.fleet_api_address:
229
+ args.fleet_api_address = (
230
+ ADDRESS_FLEET_API_GRPC_RERE
231
+ if args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE
232
+ else ADDRESS_FLEET_API_REST
233
+ )
234
+ parsed_fleet_address = parse_address(args.fleet_api_address)
235
+ if not parsed_fleet_address:
236
+ sys.exit(f"Fleet IP address ({args.fleet_api_address}) cannot be parsed.")
237
+ fleet_host, fleet_port, fleet_is_v6 = parsed_fleet_address
238
+ fleet_address = (
239
+ f"[{fleet_host}]:{fleet_port}" if fleet_is_v6 else f"{fleet_host}:{fleet_port}"
240
+ )
241
+
242
+ num_workers = args.fleet_api_num_workers
243
+ if num_workers != 1:
244
+ log(
245
+ WARN,
246
+ "The Fleet API currently supports only 1 worker. "
247
+ "You have specified %d workers. "
248
+ "Support for multiple workers will be added in future releases. "
249
+ "Proceeding with a single worker.",
250
+ args.fleet_api_num_workers,
251
+ )
252
+ num_workers = 1
341
253
 
342
254
  # Start Fleet API
343
255
  if args.fleet_api_type == TRANSPORT_TYPE_REST:
@@ -347,32 +259,25 @@ def run_superlink() -> None:
347
259
  and importlib.util.find_spec("uvicorn")
348
260
  ) is None:
349
261
  sys.exit(MISSING_EXTRA_REST)
350
- address_arg = args.rest_fleet_api_address
351
- parsed_address = parse_address(address_arg)
352
- if not parsed_address:
353
- sys.exit(f"Fleet IP address ({address_arg}) cannot be parsed.")
354
- host, port, _ = parsed_address
262
+
263
+ _, ssl_certfile, ssl_keyfile = (
264
+ certificates if certificates is not None else (None, None, None)
265
+ )
266
+
355
267
  fleet_thread = threading.Thread(
356
268
  target=_run_fleet_api_rest,
357
269
  args=(
358
- host,
359
- port,
360
- args.ssl_keyfile,
361
- args.ssl_certfile,
270
+ fleet_host,
271
+ fleet_port,
272
+ ssl_keyfile,
273
+ ssl_certfile,
362
274
  state_factory,
363
- args.rest_fleet_api_workers,
275
+ num_workers,
364
276
  ),
365
277
  )
366
278
  fleet_thread.start()
367
279
  bckg_threads.append(fleet_thread)
368
280
  elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE:
369
- address_arg = args.grpc_rere_fleet_api_address
370
- parsed_address = parse_address(address_arg)
371
- if not parsed_address:
372
- sys.exit(f"Fleet IP address ({address_arg}) cannot be parsed.")
373
- host, port, is_v6 = parsed_address
374
- address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}"
375
-
376
281
  maybe_keys = _try_setup_client_authentication(args, certificates)
377
282
  interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None
378
283
  if maybe_keys is not None:
@@ -395,23 +300,12 @@ def run_superlink() -> None:
395
300
  interceptors = [AuthenticateServerInterceptor(state)]
396
301
 
397
302
  fleet_server = _run_fleet_api_grpc_rere(
398
- address=address,
303
+ address=fleet_address,
399
304
  state_factory=state_factory,
400
305
  certificates=certificates,
401
306
  interceptors=interceptors,
402
307
  )
403
308
  grpc_servers.append(fleet_server)
404
- elif args.fleet_api_type == TRANSPORT_TYPE_VCE:
405
- f_stop = asyncio.Event() # Does nothing
406
- _run_fleet_api_vce(
407
- num_supernodes=args.num_supernodes,
408
- client_app_attr=args.client_app,
409
- backend_name=args.backend,
410
- backend_config_json_stream=args.backend_config,
411
- app_dir=args.app_dir,
412
- state_factory=state_factory,
413
- f_stop=f_stop,
414
- )
415
309
  else:
416
310
  raise ValueError(f"Unknown fleet_api_type: {args.fleet_api_type}")
417
311
 
@@ -435,44 +329,69 @@ def _try_setup_client_authentication(
435
329
  args: argparse.Namespace,
436
330
  certificates: Optional[Tuple[bytes, bytes, bytes]],
437
331
  ) -> Optional[Tuple[Set[bytes], ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]]:
438
- if not args.require_client_authentication:
332
+ if (
333
+ not args.auth_list_public_keys
334
+ and not args.auth_superlink_private_key
335
+ and not args.auth_superlink_public_key
336
+ ):
439
337
  return None
440
338
 
339
+ if (
340
+ not args.auth_list_public_keys
341
+ or not args.auth_superlink_private_key
342
+ or not args.auth_superlink_public_key
343
+ ):
344
+ sys.exit(
345
+ "Authentication requires providing file paths for "
346
+ "'--auth-list-public-keys', '--auth-superlink-private-key' and "
347
+ "'--auth-superlink-public-key'. Provide all three to enable authentication."
348
+ )
349
+
441
350
  if certificates is None:
442
351
  sys.exit(
443
- "Client authentication only works over secure connections. "
444
- "Please provide certificate paths using '--certificates' when "
445
- "enabling '--require-client-authentication'."
352
+ "Authentication requires secure connections. "
353
+ "Please provide certificate paths to `--ssl-certfile`, "
354
+ "`--ssl-keyfile`, and `—-ssl-ca-certfile` and try again."
446
355
  )
447
356
 
448
- client_keys_file_path = Path(args.require_client_authentication[0])
357
+ client_keys_file_path = Path(args.auth_list_public_keys)
449
358
  if not client_keys_file_path.exists():
450
359
  sys.exit(
451
- "The provided path to the client public keys CSV file does not exist: "
360
+ "The provided path to the known public keys CSV file does not exist: "
452
361
  f"{client_keys_file_path}. "
453
- "Please provide the CSV file path containing known client public keys "
454
- "to '--require-client-authentication'."
362
+ "Please provide the CSV file path containing known public keys "
363
+ "to '--auth-list-public-keys'."
455
364
  )
456
365
 
457
366
  client_public_keys: Set[bytes] = set()
458
- ssh_private_key = load_ssh_private_key(
459
- Path(args.require_client_authentication[1]).read_bytes(),
460
- None,
461
- )
462
- ssh_public_key = load_ssh_public_key(
463
- Path(args.require_client_authentication[2]).read_bytes()
464
- )
465
367
 
466
368
  try:
467
- server_private_key, server_public_key = ssh_types_to_elliptic_curve(
468
- ssh_private_key, ssh_public_key
369
+ ssh_private_key = load_ssh_private_key(
370
+ Path(args.auth_superlink_private_key).read_bytes(),
371
+ None,
372
+ )
373
+ if not isinstance(ssh_private_key, ec.EllipticCurvePrivateKey):
374
+ raise ValueError()
375
+ except (ValueError, UnsupportedAlgorithm):
376
+ sys.exit(
377
+ "Error: Unable to parse the private key file in "
378
+ "'--auth-superlink-private-key'. Authentication requires elliptic "
379
+ "curve private and public key pair. Please ensure that the file "
380
+ "path points to a valid private key file and try again."
381
+ )
382
+
383
+ try:
384
+ ssh_public_key = load_ssh_public_key(
385
+ Path(args.auth_superlink_public_key).read_bytes()
469
386
  )
470
- except TypeError:
387
+ if not isinstance(ssh_public_key, ec.EllipticCurvePublicKey):
388
+ raise ValueError()
389
+ except (ValueError, UnsupportedAlgorithm):
471
390
  sys.exit(
472
- "The file paths provided could not be read as a private and public "
473
- "key pair. Client authentication requires an elliptic curve public and "
474
- "private key pair. Please provide the file paths containing elliptic "
475
- "curve private and public keys to '--require-client-authentication'."
391
+ "Error: Unable to parse the public key file in "
392
+ "'--auth-superlink-public-key'. Authentication requires elliptic "
393
+ "curve private and public key pair. Please ensure that the file "
394
+ "path points to a valid public key file and try again."
476
395
  )
477
396
 
478
397
  with open(client_keys_file_path, newline="", encoding="utf-8") as csvfile:
@@ -484,14 +403,14 @@ def _try_setup_client_authentication(
484
403
  client_public_keys.add(public_key_to_bytes(public_key))
485
404
  else:
486
405
  sys.exit(
487
- "Error: Unable to parse the public keys in the .csv "
488
- "file. Please ensure that the .csv file contains valid "
489
- "SSH public keys and try again."
406
+ "Error: Unable to parse the public keys in the CSV "
407
+ "file. Please ensure that the CSV file path points to a valid "
408
+ "known SSH public keys files and try again."
490
409
  )
491
410
  return (
492
411
  client_public_keys,
493
- server_private_key,
494
- server_public_key,
412
+ ssh_private_key,
413
+ ssh_public_key,
495
414
  )
496
415
 
497
416
 
@@ -501,21 +420,52 @@ def _try_obtain_certificates(
501
420
  # Obtain certificates
502
421
  if args.insecure:
503
422
  log(WARN, "Option `--insecure` was set. Starting insecure HTTP server.")
504
- certificates = None
423
+ return None
505
424
  # Check if certificates are provided
506
- elif args.certificates:
507
- certificates = (
508
- Path(args.certificates[0]).read_bytes(), # CA certificate
509
- Path(args.certificates[1]).read_bytes(), # server certificate
510
- Path(args.certificates[2]).read_bytes(), # server private key
511
- )
512
- else:
513
- sys.exit(
514
- "Certificates are required unless running in insecure mode. "
515
- "Please provide certificate paths with '--certificates' or run the server "
516
- "in insecure mode using '--insecure' if you understand the risks."
517
- )
518
- return certificates
425
+ if args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE:
426
+ if args.ssl_certfile and args.ssl_keyfile and args.ssl_ca_certfile:
427
+ if not isfile(args.ssl_ca_certfile):
428
+ sys.exit("Path argument `--ssl-ca-certfile` does not point to a file.")
429
+ if not isfile(args.ssl_certfile):
430
+ sys.exit("Path argument `--ssl-certfile` does not point to a file.")
431
+ if not isfile(args.ssl_keyfile):
432
+ sys.exit("Path argument `--ssl-keyfile` does not point to a file.")
433
+ certificates = (
434
+ Path(args.ssl_ca_certfile).read_bytes(), # CA certificate
435
+ Path(args.ssl_certfile).read_bytes(), # server certificate
436
+ Path(args.ssl_keyfile).read_bytes(), # server private key
437
+ )
438
+ return certificates
439
+ if args.ssl_certfile or args.ssl_keyfile or args.ssl_ca_certfile:
440
+ sys.exit(
441
+ "You need to provide valid file paths to `--ssl-certfile`, "
442
+ "`--ssl-keyfile`, and `—-ssl-ca-certfile` to create a secure "
443
+ "connection in Fleet API server (gRPC-rere)."
444
+ )
445
+ if args.fleet_api_type == TRANSPORT_TYPE_REST:
446
+ if args.ssl_certfile and args.ssl_keyfile:
447
+ if not isfile(args.ssl_certfile):
448
+ sys.exit("Path argument `--ssl-certfile` does not point to a file.")
449
+ if not isfile(args.ssl_keyfile):
450
+ sys.exit("Path argument `--ssl-keyfile` does not point to a file.")
451
+ certificates = (
452
+ b"",
453
+ Path(args.ssl_certfile).read_bytes(), # server certificate
454
+ Path(args.ssl_keyfile).read_bytes(), # server private key
455
+ )
456
+ return certificates
457
+ if args.ssl_certfile or args.ssl_keyfile:
458
+ sys.exit(
459
+ "You need to provide valid file paths to `--ssl-certfile` "
460
+ "and `--ssl-keyfile` to create a secure connection "
461
+ "in Fleet API server (REST, experimental)."
462
+ )
463
+ sys.exit(
464
+ "Certificates are required unless running in insecure mode. "
465
+ "Please provide certificate paths to `--ssl-certfile`, "
466
+ "`--ssl-keyfile`, and `—-ssl-ca-certfile` or run the server "
467
+ "in insecure mode using '--insecure' if you understand the risks."
468
+ )
519
469
 
520
470
 
521
471
  def _run_fleet_api_grpc_rere(
@@ -544,29 +494,6 @@ def _run_fleet_api_grpc_rere(
544
494
  return fleet_grpc_server
545
495
 
546
496
 
547
- # pylint: disable=too-many-arguments
548
- def _run_fleet_api_vce(
549
- num_supernodes: int,
550
- client_app_attr: str,
551
- backend_name: str,
552
- backend_config_json_stream: str,
553
- app_dir: str,
554
- state_factory: StateFactory,
555
- f_stop: asyncio.Event,
556
- ) -> None:
557
- log(INFO, "Flower VCE: Starting Fleet API (VirtualClientEngine)")
558
-
559
- start_vce(
560
- num_supernodes=num_supernodes,
561
- client_app_attr=client_app_attr,
562
- backend_name=backend_name,
563
- backend_config_json_stream=backend_config_json_stream,
564
- state_factory=state_factory,
565
- app_dir=app_dir,
566
- f_stop=f_stop,
567
- )
568
-
569
-
570
497
  # pylint: disable=import-outside-toplevel,too-many-arguments
571
498
  def _run_fleet_api_rest(
572
499
  host: str,
@@ -574,7 +501,7 @@ def _run_fleet_api_rest(
574
501
  ssl_keyfile: Optional[str],
575
502
  ssl_certfile: Optional[str],
576
503
  state_factory: StateFactory,
577
- workers: int,
504
+ num_workers: int,
578
505
  ) -> None:
579
506
  """Run Driver API (REST-based)."""
580
507
  try:
@@ -583,25 +510,12 @@ def _run_fleet_api_rest(
583
510
  from flwr.server.superlink.fleet.rest_rere.rest_api import app as fast_api_app
584
511
  except ModuleNotFoundError:
585
512
  sys.exit(MISSING_EXTRA_REST)
586
- if workers != 1:
587
- raise ValueError(
588
- f"The supported number of workers for the Fleet API (REST server) is "
589
- f"1. Instead given {workers}. The functionality of >1 workers will be "
590
- f"added in the future releases."
591
- )
513
+
592
514
  log(INFO, "Starting Flower REST server")
593
515
 
594
516
  # See: https://www.starlette.io/applications/#accessing-the-app-instance
595
517
  fast_api_app.state.STATE_FACTORY = state_factory
596
518
 
597
- validation_exceptions = _validate_ssl_files(
598
- ssl_certfile=ssl_certfile, ssl_keyfile=ssl_keyfile
599
- )
600
- if any(validation_exceptions):
601
- # Starting with 3.11 we can use ExceptionGroup but for now
602
- # this seems to be the reasonable approach.
603
- raise ValueError(validation_exceptions)
604
-
605
519
  uvicorn.run(
606
520
  app="flwr.server.superlink.fleet.rest_rere.rest_api:app",
607
521
  port=port,
@@ -610,70 +524,10 @@ def _run_fleet_api_rest(
610
524
  access_log=True,
611
525
  ssl_keyfile=ssl_keyfile,
612
526
  ssl_certfile=ssl_certfile,
613
- workers=workers,
527
+ workers=num_workers,
614
528
  )
615
529
 
616
530
 
617
- def _validate_ssl_files(
618
- ssl_keyfile: Optional[str], ssl_certfile: Optional[str]
619
- ) -> List[ValueError]:
620
- validation_exceptions = []
621
-
622
- if ssl_keyfile is not None and not isfile(ssl_keyfile):
623
- msg = "Path argument `--ssl-keyfile` does not point to a file."
624
- log(ERROR, msg)
625
- validation_exceptions.append(ValueError(msg))
626
-
627
- if ssl_certfile is not None and not isfile(ssl_certfile):
628
- msg = "Path argument `--ssl-certfile` does not point to a file."
629
- log(ERROR, msg)
630
- validation_exceptions.append(ValueError(msg))
631
-
632
- if not bool(ssl_keyfile) == bool(ssl_certfile):
633
- msg = (
634
- "When setting one of `--ssl-keyfile` and "
635
- "`--ssl-certfile`, both have to be used."
636
- )
637
- log(ERROR, msg)
638
- validation_exceptions.append(ValueError(msg))
639
-
640
- return validation_exceptions
641
-
642
-
643
- def _parse_args_run_driver_api() -> argparse.ArgumentParser:
644
- """Parse command line arguments for Driver API."""
645
- parser = argparse.ArgumentParser(
646
- description="Start a Flower Driver API server. "
647
- "This server will be responsible for "
648
- "receiving TaskIns from the Driver script and "
649
- "sending them to the Fleet API. Once the client nodes "
650
- "are done, they will send the TaskRes back to this Driver API server (through"
651
- " the Fleet API) which will then send them back to the Driver script.",
652
- )
653
-
654
- _add_args_common(parser=parser)
655
- _add_args_driver_api(parser=parser)
656
-
657
- return parser
658
-
659
-
660
- def _parse_args_run_fleet_api() -> argparse.ArgumentParser:
661
- """Parse command line arguments for Fleet API."""
662
- parser = argparse.ArgumentParser(
663
- description="Start a Flower Fleet API server."
664
- "This server will be responsible for "
665
- "sending TaskIns (received from the Driver API) to the client nodes "
666
- "and of receiving TaskRes sent back from those same client nodes once "
667
- "they are done. Then, this Fleet API server can send those "
668
- "TaskRes back to the Driver API.",
669
- )
670
-
671
- _add_args_common(parser=parser)
672
- _add_args_fleet_api(parser=parser)
673
-
674
- return parser
675
-
676
-
677
531
  def _parse_args_run_superlink() -> argparse.ArgumentParser:
678
532
  """Parse command line arguments for both Driver API and Fleet API."""
679
533
  parser = argparse.ArgumentParser(
@@ -696,13 +550,23 @@ def _add_args_common(parser: argparse.ArgumentParser) -> None:
696
550
  "Use this flag only if you understand the risks.",
697
551
  )
698
552
  parser.add_argument(
699
- "--certificates",
700
- nargs=3,
701
- metavar=("CA_CERT", "SERVER_CERT", "PRIVATE_KEY"),
553
+ "--ssl-certfile",
554
+ help="Fleet API server SSL certificate file (as a path str) "
555
+ "to create a secure connection.",
556
+ type=str,
557
+ default=None,
558
+ )
559
+ parser.add_argument(
560
+ "--ssl-keyfile",
561
+ help="Fleet API server SSL private key file (as a path str) "
562
+ "to create a secure connection.",
563
+ type=str,
564
+ )
565
+ parser.add_argument(
566
+ "--ssl-ca-certfile",
567
+ help="Fleet API server SSL CA certificate file (as a path str) "
568
+ "to create a secure connection.",
702
569
  type=str,
703
- help="Paths to the CA certificate, server certificate, and server private "
704
- "key, in that order. Note: The server can only be started without "
705
- "certificates by enabling the `--insecure` flag.",
706
570
  )
707
571
  parser.add_argument(
708
572
  "--database",
@@ -714,116 +578,47 @@ def _add_args_common(parser: argparse.ArgumentParser) -> None:
714
578
  default=DATABASE,
715
579
  )
716
580
  parser.add_argument(
717
- "--require-client-authentication",
718
- nargs=3,
719
- metavar=("CLIENT_KEYS", "SERVER_PRIVATE_KEY", "SERVER_PUBLIC_KEY"),
581
+ "--auth-list-public-keys",
720
582
  type=str,
721
- help="Provide three file paths: (1) a .csv file containing a list of "
722
- "known client public keys for authentication, (2) the server's private "
723
- "key file, and (3) the server's public key file.",
583
+ help="A CSV file (as a path str) containing a list of known public "
584
+ "keys to enable authentication.",
585
+ )
586
+ parser.add_argument(
587
+ "--auth-superlink-private-key",
588
+ type=str,
589
+ help="The SuperLink's private key (as a path str) to enable authentication.",
590
+ )
591
+ parser.add_argument(
592
+ "--auth-superlink-public-key",
593
+ type=str,
594
+ help="The SuperLink's public key (as a path str) to enable authentication.",
724
595
  )
725
596
 
726
597
 
727
598
  def _add_args_driver_api(parser: argparse.ArgumentParser) -> None:
728
599
  parser.add_argument(
729
600
  "--driver-api-address",
730
- help="Driver API (gRPC) server address (IPv4, IPv6, or a domain name)",
601
+ help="Driver API (gRPC) server address (IPv4, IPv6, or a domain name).",
731
602
  default=ADDRESS_DRIVER_API,
732
603
  )
733
604
 
734
605
 
735
606
  def _add_args_fleet_api(parser: argparse.ArgumentParser) -> None:
736
607
  # Fleet API transport layer type
737
- ex_group = parser.add_mutually_exclusive_group()
738
- ex_group.add_argument(
739
- "--grpc-rere",
740
- action="store_const",
741
- dest="fleet_api_type",
742
- const=TRANSPORT_TYPE_GRPC_RERE,
608
+ parser.add_argument(
609
+ "--fleet-api-type",
743
610
  default=TRANSPORT_TYPE_GRPC_RERE,
744
- help="Start a Fleet API server (gRPC-rere)",
745
- )
746
- ex_group.add_argument(
747
- "--rest",
748
- action="store_const",
749
- dest="fleet_api_type",
750
- const=TRANSPORT_TYPE_REST,
751
- help="Start a Fleet API server (REST, experimental)",
752
- )
753
-
754
- ex_group.add_argument(
755
- "--vce",
756
- action="store_const",
757
- dest="fleet_api_type",
758
- const=TRANSPORT_TYPE_VCE,
759
- help="Start a Fleet API server (VirtualClientEngine)",
760
- )
761
-
762
- # Fleet API gRPC-rere options
763
- grpc_rere_group = parser.add_argument_group(
764
- "Fleet API (gRPC-rere) server options", ""
765
- )
766
- grpc_rere_group.add_argument(
767
- "--grpc-rere-fleet-api-address",
768
- help="Fleet API (gRPC-rere) server address (IPv4, IPv6, or a domain name)",
769
- default=ADDRESS_FLEET_API_GRPC_RERE,
770
- )
771
-
772
- # Fleet API REST options
773
- rest_group = parser.add_argument_group("Fleet API (REST) server options", "")
774
- rest_group.add_argument(
775
- "--rest-fleet-api-address",
776
- help="Fleet API (REST) server address (IPv4, IPv6, or a domain name)",
777
- default=ADDRESS_FLEET_API_REST,
778
- )
779
- rest_group.add_argument(
780
- "--ssl-certfile",
781
- help="Fleet API (REST) server SSL certificate file (as a path str), "
782
- "needed for using 'https'.",
783
- default=None,
784
- )
785
- rest_group.add_argument(
786
- "--ssl-keyfile",
787
- help="Fleet API (REST) server SSL private key file (as a path str), "
788
- "needed for using 'https'.",
789
- default=None,
790
- )
791
- rest_group.add_argument(
792
- "--rest-fleet-api-workers",
793
- help="Set the number of concurrent workers for the Fleet API REST server.",
794
- type=int,
795
- default=1,
796
- )
797
-
798
- # Fleet API VCE options
799
- vce_group = parser.add_argument_group("Fleet API (VCE) server options", "")
800
- vce_group.add_argument(
801
- "--client-app",
802
- help="For example: `client:app` or `project.package.module:wrapper.app`.",
803
- )
804
- vce_group.add_argument(
805
- "--num-supernodes",
806
- type=int,
807
- help="Number of simulated SuperNodes.",
808
- )
809
- vce_group.add_argument(
810
- "--backend",
811
- default="ray",
812
611
  type=str,
813
- help="Simulation backend that executes the ClientApp.",
612
+ choices=[TRANSPORT_TYPE_GRPC_RERE, TRANSPORT_TYPE_REST],
613
+ help="Start a gRPC-rere or REST (experimental) Fleet API server.",
814
614
  )
815
- vce_group.add_argument(
816
- "--backend-config",
817
- type=str,
818
- default='{"client_resources": {"num_cpus":1, "num_gpus":0.0}, "tensorflow": 0}',
819
- help='A JSON formatted stream, e.g \'{"<keyA>":<value>, "<keyB>":<value>}\' to '
820
- "configure a backend. Values supported in <value> are those included by "
821
- "`flwr.common.typing.ConfigsRecordValues`. ",
615
+ parser.add_argument(
616
+ "--fleet-api-address",
617
+ help="Fleet API server address (IPv4, IPv6, or a domain name).",
822
618
  )
823
619
  parser.add_argument(
824
- "--app-dir",
825
- default="",
826
- help="Add specified directory to the PYTHONPATH and load"
827
- "ClientApp from there."
828
- " Default: current working directory.",
620
+ "--fleet-api-num-workers",
621
+ default=1,
622
+ type=int,
623
+ help="Set the number of concurrent workers for the Fleet API server.",
829
624
  )