flwr 1.14.0__py3-none-any.whl → 1.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (102)
  1. flwr/cli/auth_plugin/__init__.py +31 -0
  2. flwr/cli/auth_plugin/oidc_cli_plugin.py +150 -0
  3. flwr/cli/cli_user_auth_interceptor.py +6 -2
  4. flwr/cli/config_utils.py +24 -147
  5. flwr/cli/constant.py +27 -0
  6. flwr/cli/install.py +1 -1
  7. flwr/cli/log.py +18 -3
  8. flwr/cli/login/login.py +43 -8
  9. flwr/cli/ls.py +14 -5
  10. flwr/cli/new/templates/app/README.md.tpl +3 -2
  11. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +4 -4
  12. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
  13. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +4 -4
  14. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +2 -2
  15. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -2
  16. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +2 -2
  17. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +4 -4
  18. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +3 -3
  19. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
  20. flwr/cli/run/run.py +21 -11
  21. flwr/cli/stop.py +13 -4
  22. flwr/cli/utils.py +54 -40
  23. flwr/client/app.py +36 -48
  24. flwr/client/clientapp/app.py +19 -25
  25. flwr/client/clientapp/utils.py +1 -1
  26. flwr/client/grpc_client/connection.py +1 -12
  27. flwr/client/grpc_rere_client/client_interceptor.py +19 -119
  28. flwr/client/grpc_rere_client/connection.py +46 -36
  29. flwr/client/grpc_rere_client/grpc_adapter.py +12 -12
  30. flwr/client/message_handler/task_handler.py +0 -17
  31. flwr/client/rest_client/connection.py +34 -26
  32. flwr/client/supernode/app.py +18 -72
  33. flwr/common/args.py +25 -47
  34. flwr/common/auth_plugin/auth_plugin.py +34 -23
  35. flwr/common/config.py +166 -16
  36. flwr/common/constant.py +22 -9
  37. flwr/common/differential_privacy.py +2 -1
  38. flwr/common/exit/__init__.py +24 -0
  39. flwr/common/exit/exit.py +99 -0
  40. flwr/common/exit/exit_code.py +93 -0
  41. flwr/common/exit_handlers.py +24 -10
  42. flwr/common/grpc.py +167 -4
  43. flwr/common/logger.py +26 -7
  44. flwr/common/record/recordset.py +1 -1
  45. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +45 -0
  46. flwr/common/serde.py +6 -4
  47. flwr/common/typing.py +20 -0
  48. flwr/proto/clientappio_pb2.py +1 -1
  49. flwr/proto/error_pb2.py +1 -1
  50. flwr/proto/exec_pb2.py +13 -25
  51. flwr/proto/exec_pb2.pyi +27 -54
  52. flwr/proto/fab_pb2.py +1 -1
  53. flwr/proto/fleet_pb2.py +31 -31
  54. flwr/proto/fleet_pb2.pyi +23 -23
  55. flwr/proto/fleet_pb2_grpc.py +30 -30
  56. flwr/proto/fleet_pb2_grpc.pyi +20 -20
  57. flwr/proto/grpcadapter_pb2.py +1 -1
  58. flwr/proto/log_pb2.py +1 -1
  59. flwr/proto/message_pb2.py +1 -1
  60. flwr/proto/node_pb2.py +3 -3
  61. flwr/proto/node_pb2.pyi +1 -4
  62. flwr/proto/recordset_pb2.py +1 -1
  63. flwr/proto/run_pb2.py +1 -1
  64. flwr/proto/serverappio_pb2.py +24 -25
  65. flwr/proto/serverappio_pb2.pyi +26 -32
  66. flwr/proto/serverappio_pb2_grpc.py +28 -28
  67. flwr/proto/serverappio_pb2_grpc.pyi +16 -16
  68. flwr/proto/simulationio_pb2.py +1 -1
  69. flwr/proto/task_pb2.py +1 -1
  70. flwr/proto/transport_pb2.py +1 -1
  71. flwr/server/app.py +116 -128
  72. flwr/server/compat/app_utils.py +0 -1
  73. flwr/server/compat/driver_client_proxy.py +1 -2
  74. flwr/server/driver/grpc_driver.py +32 -27
  75. flwr/server/driver/inmemory_driver.py +2 -1
  76. flwr/server/serverapp/app.py +12 -10
  77. flwr/server/superlink/driver/serverappio_grpc.py +1 -1
  78. flwr/server/superlink/driver/serverappio_servicer.py +74 -48
  79. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +20 -88
  80. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +2 -165
  81. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +25 -24
  82. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +97 -168
  83. flwr/server/superlink/fleet/message_handler/message_handler.py +37 -24
  84. flwr/server/superlink/fleet/rest_rere/rest_api.py +16 -18
  85. flwr/server/superlink/fleet/vce/vce_api.py +2 -2
  86. flwr/server/superlink/linkstate/in_memory_linkstate.py +45 -75
  87. flwr/server/superlink/linkstate/linkstate.py +17 -38
  88. flwr/server/superlink/linkstate/sqlite_linkstate.py +81 -145
  89. flwr/server/superlink/linkstate/utils.py +18 -8
  90. flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
  91. flwr/server/utils/validator.py +9 -34
  92. flwr/simulation/app.py +4 -6
  93. flwr/simulation/legacy_app.py +4 -2
  94. flwr/simulation/run_simulation.py +1 -1
  95. flwr/simulation/simulationio_connection.py +2 -1
  96. flwr/superexec/exec_grpc.py +1 -1
  97. flwr/superexec/exec_servicer.py +23 -2
  98. {flwr-1.14.0.dist-info → flwr-1.15.0.dist-info}/METADATA +8 -8
  99. {flwr-1.14.0.dist-info → flwr-1.15.0.dist-info}/RECORD +102 -96
  100. {flwr-1.14.0.dist-info → flwr-1.15.0.dist-info}/LICENSE +0 -0
  101. {flwr-1.14.0.dist-info → flwr-1.15.0.dist-info}/WHEEL +0 -0
  102. {flwr-1.14.0.dist-info → flwr-1.15.0.dist-info}/entry_points.txt +0 -0
flwr/server/app.py CHANGED
@@ -18,7 +18,9 @@
18
18
  import argparse
19
19
  import csv
20
20
  import importlib.util
21
- import subprocess
21
+ import multiprocessing
22
+ import multiprocessing.context
23
+ import os
22
24
  import sys
23
25
  import threading
24
26
  from collections.abc import Sequence
@@ -29,12 +31,8 @@ from typing import Any, Optional
29
31
 
30
32
  import grpc
31
33
  import yaml
32
- from cryptography.exceptions import UnsupportedAlgorithm
33
34
  from cryptography.hazmat.primitives.asymmetric import ec
34
- from cryptography.hazmat.primitives.serialization import (
35
- load_ssh_private_key,
36
- load_ssh_public_key,
37
- )
35
+ from cryptography.hazmat.primitives.serialization import load_ssh_public_key
38
36
 
39
37
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
40
38
  from flwr.common.address import parse_address
@@ -42,7 +40,7 @@ from flwr.common.args import try_obtain_server_certificates
42
40
  from flwr.common.auth_plugin import ExecAuthPlugin
43
41
  from flwr.common.config import get_flwr_dir, parse_config_args
44
42
  from flwr.common.constant import (
45
- AUTH_TYPE,
43
+ AUTH_TYPE_KEY,
46
44
  CLIENT_OCTET,
47
45
  EXEC_API_DEFAULT_SERVER_ADDRESS,
48
46
  FLEET_API_GRPC_BIDI_DEFAULT_ADDRESS,
@@ -50,7 +48,6 @@ from flwr.common.constant import (
50
48
  FLEET_API_REST_DEFAULT_ADDRESS,
51
49
  ISOLATION_MODE_PROCESS,
52
50
  ISOLATION_MODE_SUBPROCESS,
53
- MISSING_EXTRA_REST,
54
51
  SERVER_OCTET,
55
52
  SERVERAPPIO_API_DEFAULT_SERVER_ADDRESS,
56
53
  SIMULATIONIO_API_DEFAULT_SERVER_ADDRESS,
@@ -58,16 +55,19 @@ from flwr.common.constant import (
58
55
  TRANSPORT_TYPE_GRPC_RERE,
59
56
  TRANSPORT_TYPE_REST,
60
57
  )
58
+ from flwr.common.exit import ExitCode, flwr_exit
61
59
  from flwr.common.exit_handlers import register_exit_handlers
60
+ from flwr.common.grpc import generic_create_grpc_server
62
61
  from flwr.common.logger import log, warn_deprecated_feature
63
62
  from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
64
- private_key_to_bytes,
65
63
  public_key_to_bytes,
66
64
  )
67
65
  from flwr.proto.fleet_pb2_grpc import ( # pylint: disable=E0611
68
66
  add_FleetServicer_to_server,
69
67
  )
70
68
  from flwr.proto.grpcadapter_pb2_grpc import add_GrpcAdapterServicer_to_server
69
+ from flwr.server.serverapp.app import flwr_serverapp
70
+ from flwr.simulation.app import flwr_simulation
71
71
  from flwr.superexec.app import load_executor
72
72
  from flwr.superexec.exec_grpc import run_exec_api_grpc
73
73
 
@@ -79,10 +79,7 @@ from .strategy import Strategy
79
79
  from .superlink.driver.serverappio_grpc import run_serverappio_api_grpc
80
80
  from .superlink.ffs.ffs_factory import FfsFactory
81
81
  from .superlink.fleet.grpc_adapter.grpc_adapter_servicer import GrpcAdapterServicer
82
- from .superlink.fleet.grpc_bidi.grpc_server import (
83
- generic_create_grpc_server,
84
- start_grpc_server,
85
- )
82
+ from .superlink.fleet.grpc_bidi.grpc_server import start_grpc_server
86
83
  from .superlink.fleet.grpc_rere.fleet_servicer import FleetServicer
87
84
  from .superlink.fleet.grpc_rere.server_interceptor import AuthenticateServerInterceptor
88
85
  from .superlink.linkstate import LinkStateFactory
@@ -226,6 +223,13 @@ def start_server( # pylint: disable=too-many-arguments,too-many-locals
226
223
  "enabled" if certificates is not None else "disabled",
227
224
  )
228
225
 
226
+ # Graceful shutdown
227
+ register_exit_handlers(
228
+ event_type=EventType.START_SERVER_LEAVE,
229
+ exit_message="Flower server terminated gracefully.",
230
+ grpc_servers=[grpc_server],
231
+ )
232
+
229
233
  # Start training
230
234
  hist = run_fl(
231
235
  server=initialized_server,
@@ -261,13 +265,16 @@ def run_superlink() -> None:
261
265
  simulationio_address, _, _ = _format_address(args.simulationio_api_address)
262
266
 
263
267
  # Obtain certificates
264
- certificates = try_obtain_server_certificates(args, args.fleet_api_type)
268
+ certificates = try_obtain_server_certificates(args)
269
+
270
+ # Disable the user auth TLS check if args.disable_oidc_tls_cert_verification is
271
+ # provided
272
+ verify_tls_cert = not getattr(args, "disable_oidc_tls_cert_verification", None)
265
273
 
266
- user_auth_config = _try_obtain_user_auth_config(args)
267
274
  auth_plugin: Optional[ExecAuthPlugin] = None
268
- # user_auth_config is None only if the args.user_auth_config is not provided
269
- if user_auth_config is not None:
270
- auth_plugin = _try_obtain_exec_auth_plugin(user_auth_config)
275
+ # Load the auth plugin if the args.user_auth_config is provided
276
+ if cfg_path := getattr(args, "user_auth_config", None):
277
+ auth_plugin = _try_obtain_exec_auth_plugin(Path(cfg_path), verify_tls_cert)
271
278
 
272
279
  # Initialize StateFactory
273
280
  state_factory = LinkStateFactory(args.database)
@@ -293,7 +300,7 @@ def run_superlink() -> None:
293
300
  # Determine Exec plugin
294
301
  # If simulation is used, don't start ServerAppIo and Fleet APIs
295
302
  sim_exec = executor.__class__.__qualname__ == "SimulationEngine"
296
- bckg_threads = []
303
+ bckg_threads: list[threading.Thread] = []
297
304
 
298
305
  if sim_exec:
299
306
  simulationio_server: grpc.Server = run_simulationio_api_grpc(
@@ -344,48 +351,40 @@ def run_superlink() -> None:
344
351
  and importlib.util.find_spec("starlette")
345
352
  and importlib.util.find_spec("uvicorn")
346
353
  ) is None:
347
- sys.exit(MISSING_EXTRA_REST)
348
-
349
- _, ssl_certfile, ssl_keyfile = (
350
- certificates if certificates is not None else (None, None, None)
351
- )
354
+ flwr_exit(ExitCode.COMMON_MISSING_EXTRA_REST)
352
355
 
353
356
  fleet_thread = threading.Thread(
354
357
  target=_run_fleet_api_rest,
355
358
  args=(
356
359
  host,
357
360
  port,
358
- ssl_keyfile,
359
- ssl_certfile,
361
+ args.ssl_keyfile,
362
+ args.ssl_certfile,
360
363
  state_factory,
361
364
  ffs_factory,
362
365
  num_workers,
363
366
  ),
367
+ daemon=True,
364
368
  )
365
369
  fleet_thread.start()
366
370
  bckg_threads.append(fleet_thread)
367
371
  elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE:
368
- maybe_keys = _try_setup_node_authentication(args, certificates)
369
- interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None
370
- if maybe_keys is not None:
371
- (
372
- node_public_keys,
373
- server_private_key,
374
- server_public_key,
375
- ) = maybe_keys
372
+ node_public_keys = _try_load_public_keys_node_authentication(args)
373
+ auto_auth = True
374
+ if node_public_keys is not None:
375
+ auto_auth = False
376
376
  state = state_factory.state()
377
- state.clear_supernode_auth_keys_and_credentials()
377
+ state.clear_supernode_auth_keys()
378
378
  state.store_node_public_keys(node_public_keys)
379
- state.store_server_private_public_key(
380
- private_key_to_bytes(server_private_key),
381
- public_key_to_bytes(server_public_key),
382
- )
383
379
  log(
384
380
  INFO,
385
381
  "Node authentication enabled with %d known public keys",
386
382
  len(node_public_keys),
387
383
  )
388
- interceptors = [AuthenticateServerInterceptor(state_factory)]
384
+ else:
385
+ log(DEBUG, "Automatic node authentication enabled")
386
+
387
+ interceptors = [AuthenticateServerInterceptor(state_factory, auto_auth)]
389
388
 
390
389
  fleet_server = _run_fleet_api_grpc_rere(
391
390
  address=fleet_address,
@@ -427,6 +426,7 @@ def run_superlink() -> None:
427
426
  address,
428
427
  cmd,
429
428
  ),
429
+ daemon=True,
430
430
  )
431
431
  scheduler_th.start()
432
432
  bckg_threads.append(scheduler_th)
@@ -434,17 +434,37 @@ def run_superlink() -> None:
434
434
  # Graceful shutdown
435
435
  register_exit_handlers(
436
436
  event_type=EventType.RUN_SUPERLINK_LEAVE,
437
+ exit_message="SuperLink terminated gracefully.",
437
438
  grpc_servers=grpc_servers,
438
- bckg_threads=bckg_threads,
439
439
  )
440
440
 
441
- # Block
442
- while True:
443
- if bckg_threads:
444
- for thread in bckg_threads:
445
- if not thread.is_alive():
446
- sys.exit(1)
447
- exec_server.wait_for_termination(timeout=1)
441
+ # Block until a thread exits prematurely
442
+ while all(thread.is_alive() for thread in bckg_threads):
443
+ sleep(0.1)
444
+
445
+ # Exit if any thread has exited prematurely
446
+ # This code will not be reached if the SuperLink stops gracefully
447
+ flwr_exit(ExitCode.SUPERLINK_THREAD_CRASH)
448
+
449
+
450
+ def _run_flwr_command(args: list[str], main_pid: int) -> None:
451
+ # Monitor the main process in case of SIGKILL
452
+ def main_process_monitor() -> None:
453
+ while True:
454
+ sleep(1)
455
+ if os.getppid() != main_pid:
456
+ os.kill(os.getpid(), 9)
457
+
458
+ threading.Thread(target=main_process_monitor, daemon=True).start()
459
+
460
+ # Run the command
461
+ sys.argv = args
462
+ if args[0] == "flwr-serverapp":
463
+ flwr_serverapp()
464
+ elif args[0] == "flwr-simulation":
465
+ flwr_simulation()
466
+ else:
467
+ raise ValueError(f"Unknown command: {args[0]}")
448
468
 
449
469
 
450
470
  def _flwr_scheduler(
@@ -454,15 +474,18 @@ def _flwr_scheduler(
454
474
  cmd: str,
455
475
  ) -> None:
456
476
  log(DEBUG, "Started %s scheduler thread.", cmd)
457
-
458
477
  state = state_factory.state()
478
+ run_id_to_proc: dict[int, multiprocessing.context.SpawnProcess] = {}
479
+
480
+ # Use the "spawn" start method for multiprocessing.
481
+ mp_spawn_context = multiprocessing.get_context("spawn")
459
482
 
460
483
  # Periodically check for a pending run in the LinkState
461
484
  while True:
462
- sleep(3)
485
+ sleep(0.1)
463
486
  pending_run_id = state.get_pending_run_id()
464
487
 
465
- if pending_run_id:
488
+ if pending_run_id and pending_run_id not in run_id_to_proc:
466
489
 
467
490
  log(
468
491
  INFO,
@@ -479,50 +502,45 @@ def _flwr_scheduler(
479
502
  "--insecure",
480
503
  ]
481
504
 
482
- subprocess.Popen( # pylint: disable=consider-using-with
483
- command,
484
- text=True,
505
+ proc = mp_spawn_context.Process(
506
+ target=_run_flwr_command, args=(command, os.getpid()), daemon=True
485
507
  )
508
+ proc.start()
509
+
510
+ # Store the process
511
+ run_id_to_proc[pending_run_id] = proc
512
+
513
+ # Clean up finished processes
514
+ for run_id, proc in list(run_id_to_proc.items()):
515
+ if not proc.is_alive():
516
+ del run_id_to_proc[run_id]
486
517
 
487
518
 
488
519
  def _format_address(address: str) -> tuple[str, str, int]:
489
520
  parsed_address = parse_address(address)
490
521
  if not parsed_address:
491
- sys.exit(
492
- f"Address ({address}) cannot be parsed (expected: URL or IPv4 or IPv6)."
522
+ flwr_exit(
523
+ ExitCode.COMMON_ADDRESS_INVALID,
524
+ f"Address ({address}) cannot be parsed.",
493
525
  )
494
526
  host, port, is_v6 = parsed_address
495
527
  return (f"[{host}]:{port}" if is_v6 else f"{host}:{port}", host, port)
496
528
 
497
529
 
498
- def _try_setup_node_authentication(
530
+ def _try_load_public_keys_node_authentication(
499
531
  args: argparse.Namespace,
500
- certificates: Optional[tuple[bytes, bytes, bytes]],
501
- ) -> Optional[tuple[set[bytes], ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]]:
502
- if (
503
- not args.auth_list_public_keys
504
- and not args.auth_superlink_private_key
505
- and not args.auth_superlink_public_key
506
- ):
507
- return None
508
-
509
- if (
510
- not args.auth_list_public_keys
511
- or not args.auth_superlink_private_key
512
- or not args.auth_superlink_public_key
513
- ):
514
- sys.exit(
515
- "Authentication requires providing file paths for "
516
- "'--auth-list-public-keys', '--auth-superlink-private-key' and "
517
- "'--auth-superlink-public-key'. Provide all three to enable authentication."
532
+ ) -> Optional[set[bytes]]:
533
+ """Return a set of node public keys."""
534
+ if args.auth_superlink_private_key or args.auth_superlink_public_key:
535
+ log(
536
+ WARN,
537
+ "The `--auth-superlink-private-key` and `--auth-superlink-public-key` "
538
+ "arguments are deprecated and will be removed in a future release. Node "
539
+ "authentication no longer requires these arguments.",
518
540
  )
519
541
 
520
- if certificates is None:
521
- sys.exit(
522
- "Authentication requires secure connections. "
523
- "Please provide certificate paths to `--ssl-certfile`, "
524
- "`--ssl-keyfile`, and `—-ssl-ca-certfile` and try again."
525
- )
542
+ if not args.auth_list_public_keys:
543
+ return None
526
544
 
527
545
  node_keys_file_path = Path(args.auth_list_public_keys)
528
546
  if not node_keys_file_path.exists():
@@ -535,35 +553,6 @@ def _try_setup_node_authentication(
535
553
 
536
554
  node_public_keys: set[bytes] = set()
537
555
 
538
- try:
539
- ssh_private_key = load_ssh_private_key(
540
- Path(args.auth_superlink_private_key).read_bytes(),
541
- None,
542
- )
543
- if not isinstance(ssh_private_key, ec.EllipticCurvePrivateKey):
544
- raise ValueError()
545
- except (ValueError, UnsupportedAlgorithm):
546
- sys.exit(
547
- "Error: Unable to parse the private key file in "
548
- "'--auth-superlink-private-key'. Authentication requires elliptic "
549
- "curve private and public key pair. Please ensure that the file "
550
- "path points to a valid private key file and try again."
551
- )
552
-
553
- try:
554
- ssh_public_key = load_ssh_public_key(
555
- Path(args.auth_superlink_public_key).read_bytes()
556
- )
557
- if not isinstance(ssh_public_key, ec.EllipticCurvePublicKey):
558
- raise ValueError()
559
- except (ValueError, UnsupportedAlgorithm):
560
- sys.exit(
561
- "Error: Unable to parse the public key file in "
562
- "'--auth-superlink-public-key'. Authentication requires elliptic "
563
- "curve private and public key pair. Please ensure that the file "
564
- "path points to a valid public key file and try again."
565
- )
566
-
567
556
  with open(node_keys_file_path, newline="", encoding="utf-8") as csvfile:
568
557
  reader = csv.reader(csvfile)
569
558
  for row in reader:
@@ -577,28 +566,27 @@ def _try_setup_node_authentication(
577
566
  "file. Please ensure that the CSV file path points to a valid "
578
567
  "known SSH public keys files and try again."
579
568
  )
580
- return (
581
- node_public_keys,
582
- ssh_private_key,
583
- ssh_public_key,
584
- )
585
-
569
+ return node_public_keys
586
570
 
587
- def _try_obtain_user_auth_config(args: argparse.Namespace) -> Optional[dict[str, Any]]:
588
- if getattr(args, "user_auth_config", None) is not None:
589
- with open(args.user_auth_config, encoding="utf-8") as file:
590
- config: dict[str, Any] = yaml.safe_load(file)
591
- return config
592
- return None
593
571
 
572
+ def _try_obtain_exec_auth_plugin(
573
+ config_path: Path, verify_tls_cert: bool
574
+ ) -> Optional[ExecAuthPlugin]:
575
+ # Load YAML file
576
+ with config_path.open("r", encoding="utf-8") as file:
577
+ config: dict[str, Any] = yaml.safe_load(file)
594
578
 
595
- def _try_obtain_exec_auth_plugin(config: dict[str, Any]) -> Optional[ExecAuthPlugin]:
579
+ # Load authentication configuration
596
580
  auth_config: dict[str, Any] = config.get("authentication", {})
597
- auth_type: str = auth_config.get(AUTH_TYPE, "")
581
+ auth_type: str = auth_config.get(AUTH_TYPE_KEY, "")
582
+
583
+ # Load authentication plugin
598
584
  try:
599
585
  all_plugins: dict[str, type[ExecAuthPlugin]] = get_exec_auth_plugins()
600
586
  auth_plugin_class = all_plugins[auth_type]
601
- return auth_plugin_class(config=auth_config)
587
+ return auth_plugin_class(
588
+ user_auth_config_path=config_path, verify_tls_cert=verify_tls_cert
589
+ )
602
590
  except KeyError:
603
591
  if auth_type != "":
604
592
  sys.exit(
@@ -681,7 +669,7 @@ def _run_fleet_api_rest(
681
669
 
682
670
  from flwr.server.superlink.fleet.rest_rere.rest_api import app as fast_api_app
683
671
  except ModuleNotFoundError:
684
- sys.exit(MISSING_EXTRA_REST)
672
+ flwr_exit(ExitCode.COMMON_MISSING_EXTRA_REST)
685
673
 
686
674
  log(INFO, "Starting Flower REST server")
687
675
 
@@ -791,12 +779,12 @@ def _add_args_common(parser: argparse.ArgumentParser) -> None:
791
779
  parser.add_argument(
792
780
  "--auth-superlink-private-key",
793
781
  type=str,
794
- help="The SuperLink's private key (as a path str) to enable authentication.",
782
+ help="This argument is deprecated and will be removed in a future release.",
795
783
  )
796
784
  parser.add_argument(
797
785
  "--auth-superlink-public-key",
798
786
  type=str,
799
- help="The SuperLink's public key (as a path str) to enable authentication.",
787
+ help="This argument is deprecated and will be removed in a future release.",
800
788
  )
801
789
 
802
790
 
@@ -95,7 +95,6 @@ def _update_client_manager(
95
95
  client_proxy = DriverClientProxy(
96
96
  node_id=node_id,
97
97
  driver=driver,
98
- anonymous=False,
99
98
  run_id=driver.run.run_id,
100
99
  )
101
100
  if client_manager.register(client_proxy):
@@ -28,12 +28,11 @@ from ..driver.driver import Driver
28
28
  class DriverClientProxy(ClientProxy):
29
29
  """Flower client proxy which delegates work using the Driver API."""
30
30
 
31
- def __init__(self, node_id: int, driver: Driver, anonymous: bool, run_id: int):
31
+ def __init__(self, node_id: int, driver: Driver, run_id: int):
32
32
  super().__init__(str(node_id))
33
33
  self.node_id = node_id
34
34
  self.driver = driver
35
35
  self.run_id = run_id
36
- self.anonymous = anonymous
37
36
 
38
37
  def get_properties(
39
38
  self,
@@ -24,29 +24,32 @@ from typing import Optional, cast
24
24
  import grpc
25
25
 
26
26
  from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
27
- from flwr.common.constant import SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS
28
- from flwr.common.grpc import create_channel
27
+ from flwr.common.constant import (
28
+ SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS,
29
+ SUPERLINK_NODE_ID,
30
+ )
31
+ from flwr.common.grpc import create_channel, on_channel_state_change
29
32
  from flwr.common.logger import log
30
33
  from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub
31
- from flwr.common.serde import message_from_taskres, message_to_taskins, run_from_proto
34
+ from flwr.common.serde import message_from_proto, message_to_proto, run_from_proto
32
35
  from flwr.common.typing import Run
36
+ from flwr.proto.message_pb2 import Message as ProtoMessage # pylint: disable=E0611
33
37
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
34
38
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
35
39
  from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
36
40
  GetNodesRequest,
37
41
  GetNodesResponse,
38
- PullTaskResRequest,
39
- PullTaskResResponse,
40
- PushTaskInsRequest,
41
- PushTaskInsResponse,
42
+ PullResMessagesRequest,
43
+ PullResMessagesResponse,
44
+ PushInsMessagesRequest,
45
+ PushInsMessagesResponse,
42
46
  )
43
47
  from flwr.proto.serverappio_pb2_grpc import ServerAppIoStub # pylint: disable=E0611
44
- from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
45
48
 
46
49
  from .driver import Driver
47
50
 
48
51
  ERROR_MESSAGE_DRIVER_NOT_CONNECTED = """
49
- [Driver] Error: Not connected.
52
+ [flwr-serverapp] Error: Not connected.
50
53
 
51
54
  Call `connect()` on the `GrpcDriverStub` instance before calling any of the other
52
55
  `GrpcDriverStub` methods.
@@ -76,7 +79,7 @@ class GrpcDriver(Driver):
76
79
  self._run: Optional[Run] = None
77
80
  self._grpc_stub: Optional[ServerAppIoStub] = None
78
81
  self._channel: Optional[grpc.Channel] = None
79
- self.node = Node(node_id=0, anonymous=True)
82
+ self.node = Node(node_id=SUPERLINK_NODE_ID)
80
83
  self._retry_invoker = _make_simple_grpc_retry_invoker()
81
84
 
82
85
  @property
@@ -97,9 +100,10 @@ class GrpcDriver(Driver):
97
100
  insecure=(self._cert is None),
98
101
  root_certificates=self._cert,
99
102
  )
103
+ self._channel.subscribe(on_channel_state_change)
100
104
  self._grpc_stub = ServerAppIoStub(self._channel)
101
105
  _wrap_stub(self._grpc_stub, self._retry_invoker)
102
- log(DEBUG, "[Driver] Connected to %s", self._addr)
106
+ log(DEBUG, "[flwr-serverapp] Connected to %s", self._addr)
103
107
 
104
108
  def _disconnect(self) -> None:
105
109
  """Disconnect from the ServerAppIo API."""
@@ -110,7 +114,7 @@ class GrpcDriver(Driver):
110
114
  self._channel = None
111
115
  self._grpc_stub = None
112
116
  channel.close()
113
- log(DEBUG, "[Driver] Disconnected")
117
+ log(DEBUG, "[flwr-serverapp] Disconnected")
114
118
 
115
119
  def set_run(self, run_id: int) -> None:
116
120
  """Set the run."""
@@ -193,22 +197,22 @@ class GrpcDriver(Driver):
193
197
  This method takes an iterable of messages and sends each message
194
198
  to the node specified in `dst_node_id`.
195
199
  """
196
- # Construct TaskIns
197
- task_ins_list: list[TaskIns] = []
200
+ # Construct Messages
201
+ message_proto_list: list[ProtoMessage] = []
198
202
  for msg in messages:
199
203
  # Check message
200
204
  self._check_message(msg)
201
- # Convert Message to TaskIns
202
- taskins = message_to_taskins(msg)
205
+ # Convert to proto
206
+ msg_proto = message_to_proto(msg)
203
207
  # Add to list
204
- task_ins_list.append(taskins)
208
+ message_proto_list.append(msg_proto)
205
209
  # Call GrpcDriverStub method
206
- res: PushTaskInsResponse = self._stub.PushTaskIns(
207
- PushTaskInsRequest(
208
- task_ins_list=task_ins_list, run_id=cast(Run, self._run).run_id
210
+ res: PushInsMessagesResponse = self._stub.PushMessages(
211
+ PushInsMessagesRequest(
212
+ messages_list=message_proto_list, run_id=cast(Run, self._run).run_id
209
213
  )
210
214
  )
211
- return list(res.task_ids)
215
+ return list(res.message_ids)
212
216
 
213
217
  def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
214
218
  """Pull messages based on message IDs.
@@ -216,14 +220,15 @@ class GrpcDriver(Driver):
216
220
  This method is used to collect messages from the SuperLink that correspond to a
217
221
  set of given message IDs.
218
222
  """
219
- # Pull TaskRes
220
- res: PullTaskResResponse = self._stub.PullTaskRes(
221
- PullTaskResRequest(
222
- node=self.node, task_ids=message_ids, run_id=cast(Run, self._run).run_id
223
+ # Pull Messages
224
+ res: PullResMessagesResponse = self._stub.PullMessages(
225
+ PullResMessagesRequest(
226
+ message_ids=message_ids,
227
+ run_id=cast(Run, self._run).run_id,
223
228
  )
224
229
  )
225
- # Convert TaskRes to Message
226
- msgs = [message_from_taskres(taskres) for taskres in res.task_res_list]
230
+ # Convert Message from Protobuf representation
231
+ msgs = [message_from_proto(msg_proto) for msg_proto in res.messages_list]
227
232
  return msgs
228
233
 
229
234
  def send_and_receive(
@@ -22,6 +22,7 @@ from typing import Optional, cast
22
22
  from uuid import UUID
23
23
 
24
24
  from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
25
+ from flwr.common.constant import SUPERLINK_NODE_ID
25
26
  from flwr.common.serde import message_from_taskres, message_to_taskins
26
27
  from flwr.common.typing import Run
27
28
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
@@ -49,7 +50,7 @@ class InMemoryDriver(Driver):
49
50
  self._run: Optional[Run] = None
50
51
  self.state = state_factory.state()
51
52
  self.pull_interval = pull_interval
52
- self.node = Node(node_id=0, anonymous=True)
53
+ self.node = Node(node_id=SUPERLINK_NODE_ID)
53
54
 
54
55
  def _check_message(self, message: Message) -> None:
55
56
  # Check if the message is valid
@@ -16,7 +16,6 @@
16
16
 
17
17
 
18
18
  import argparse
19
- import sys
20
19
  from logging import DEBUG, ERROR, INFO
21
20
  from pathlib import Path
22
21
  from queue import Queue
@@ -38,6 +37,7 @@ from flwr.common.constant import (
38
37
  Status,
39
38
  SubStatus,
40
39
  )
40
+ from flwr.common.exit import ExitCode, flwr_exit
41
41
  from flwr.common.logger import (
42
42
  log,
43
43
  mirror_output_to_queue,
@@ -72,19 +72,18 @@ def flwr_serverapp() -> None:
72
72
 
73
73
  args = _parse_args_run_flwr_serverapp().parse_args()
74
74
 
75
- log(INFO, "Starting Flower ServerApp")
75
+ log(INFO, "Start `flwr-serverapp` process")
76
76
 
77
77
  if not args.insecure:
78
- log(
79
- ERROR,
80
- "`flwr-serverapp` does not support TLS yet. "
81
- "Please use the '--insecure' flag.",
78
+ flwr_exit(
79
+ ExitCode.COMMON_TLS_NOT_SUPPORTED,
80
+ "`flwr-serverapp` does not support TLS yet.",
82
81
  )
83
- sys.exit(1)
84
82
 
85
83
  log(
86
84
  DEBUG,
87
- "Starting isolated `ServerApp` connected to SuperLink's ServerAppIo API at %s",
85
+ "`flwr-serverapp` will attempt to connect to SuperLink's "
86
+ "ServerAppIo API at %s",
88
87
  args.serverappio_api_address,
89
88
  )
90
89
  run_serverapp(
@@ -117,11 +116,13 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
117
116
  log_uploader = None
118
117
  success = True
119
118
  hash_run_id = None
119
+ run_status = None
120
120
  while True:
121
121
 
122
122
  try:
123
123
  # Pull ServerAppInputs from LinkState
124
124
  req = PullServerAppInputsRequest()
125
+ log(DEBUG, "[flwr-serverapp] Pull ServerAppInputs")
125
126
  res: PullServerAppInputsResponse = driver._stub.PullServerAppInputs(req)
126
127
  if not res.HasField("run"):
127
128
  sleep(3)
@@ -144,7 +145,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
144
145
  stub=driver._stub,
145
146
  )
146
147
 
147
- log(DEBUG, "ServerApp process starts FAB installation.")
148
+ log(DEBUG, "[flwr-serverapp] Start FAB installation.")
148
149
  install_from_fab(fab.content, flwr_dir=flwr_dir_, skip_prompt=True)
149
150
 
150
151
  fab_id, fab_version = get_fab_metadata(fab.content)
@@ -165,7 +166,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
165
166
 
166
167
  log(
167
168
  DEBUG,
168
- "Flower will load ServerApp `%s` in %s",
169
+ "[flwr-serverapp] Will load ServerApp `%s` in %s",
169
170
  server_app_attr,
170
171
  app_path,
171
172
  )
@@ -191,6 +192,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212, disable=R0915
191
192
 
192
193
  # Send resulting context
193
194
  context_proto = context_to_proto(updated_context)
195
+ log(DEBUG, "[flwr-serverapp] Will push ServerAppOutputs")
194
196
  out_req = PushServerAppOutputsRequest(
195
197
  run_id=run.run_id, context=context_proto
196
198
  )