flwr 1.13.1__py3-none-any.whl → 1.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (158)
  1. flwr/cli/app.py +5 -0
  2. flwr/cli/auth_plugin/__init__.py +31 -0
  3. flwr/cli/auth_plugin/oidc_cli_plugin.py +150 -0
  4. flwr/cli/build.py +1 -0
  5. flwr/cli/cli_user_auth_interceptor.py +90 -0
  6. flwr/cli/config_utils.py +43 -149
  7. flwr/cli/constant.py +27 -0
  8. flwr/cli/example.py +1 -0
  9. flwr/cli/install.py +2 -1
  10. flwr/cli/log.py +34 -37
  11. flwr/cli/login/__init__.py +22 -0
  12. flwr/cli/login/login.py +116 -0
  13. flwr/cli/ls.py +214 -106
  14. flwr/cli/new/__init__.py +1 -0
  15. flwr/cli/new/new.py +2 -1
  16. flwr/cli/new/templates/app/.gitignore.tpl +3 -0
  17. flwr/cli/new/templates/app/README.md.tpl +3 -2
  18. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +4 -4
  19. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +4 -4
  20. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +4 -4
  21. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +2 -2
  22. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +3 -4
  23. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +2 -2
  24. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +4 -4
  25. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +3 -3
  26. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
  27. flwr/cli/run/__init__.py +1 -0
  28. flwr/cli/run/run.py +103 -43
  29. flwr/cli/stop.py +139 -0
  30. flwr/cli/utils.py +186 -8
  31. flwr/client/app.py +49 -50
  32. flwr/client/client.py +1 -32
  33. flwr/client/clientapp/app.py +23 -26
  34. flwr/client/clientapp/utils.py +2 -1
  35. flwr/client/grpc_adapter_client/connection.py +1 -1
  36. flwr/client/grpc_client/connection.py +2 -13
  37. flwr/client/grpc_rere_client/client_interceptor.py +19 -119
  38. flwr/client/grpc_rere_client/connection.py +59 -43
  39. flwr/client/grpc_rere_client/grpc_adapter.py +12 -12
  40. flwr/client/message_handler/message_handler.py +1 -2
  41. flwr/client/message_handler/task_handler.py +0 -17
  42. flwr/client/mod/comms_mods.py +1 -0
  43. flwr/client/mod/localdp_mod.py +1 -1
  44. flwr/client/nodestate/__init__.py +1 -0
  45. flwr/client/nodestate/nodestate.py +1 -0
  46. flwr/client/nodestate/nodestate_factory.py +1 -0
  47. flwr/client/numpy_client.py +0 -44
  48. flwr/client/rest_client/connection.py +37 -29
  49. flwr/client/supernode/app.py +20 -74
  50. flwr/common/address.py +1 -0
  51. flwr/common/args.py +26 -47
  52. flwr/common/auth_plugin/__init__.py +24 -0
  53. flwr/common/auth_plugin/auth_plugin.py +122 -0
  54. flwr/common/config.py +169 -17
  55. flwr/common/constant.py +38 -9
  56. flwr/common/differential_privacy.py +2 -1
  57. flwr/common/exit/__init__.py +24 -0
  58. flwr/common/exit/exit.py +99 -0
  59. flwr/common/exit/exit_code.py +93 -0
  60. flwr/common/exit_handlers.py +24 -10
  61. flwr/common/grpc.py +167 -4
  62. flwr/common/logger.py +66 -7
  63. flwr/common/message.py +1 -0
  64. flwr/common/object_ref.py +57 -54
  65. flwr/common/pyproject.py +1 -0
  66. flwr/common/record/__init__.py +1 -0
  67. flwr/common/record/parametersrecord.py +1 -0
  68. flwr/common/record/recordset.py +1 -1
  69. flwr/common/retry_invoker.py +77 -0
  70. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +45 -0
  71. flwr/common/secure_aggregation/secaggplus_utils.py +2 -2
  72. flwr/common/serde.py +6 -4
  73. flwr/common/telemetry.py +15 -4
  74. flwr/common/typing.py +32 -0
  75. flwr/common/version.py +1 -0
  76. flwr/proto/clientappio_pb2.py +1 -1
  77. flwr/proto/error_pb2.py +1 -1
  78. flwr/proto/exec_pb2.py +27 -15
  79. flwr/proto/exec_pb2.pyi +80 -2
  80. flwr/proto/exec_pb2_grpc.py +102 -0
  81. flwr/proto/exec_pb2_grpc.pyi +39 -0
  82. flwr/proto/fab_pb2.py +5 -5
  83. flwr/proto/fab_pb2.pyi +4 -1
  84. flwr/proto/fleet_pb2.py +31 -31
  85. flwr/proto/fleet_pb2.pyi +23 -23
  86. flwr/proto/fleet_pb2_grpc.py +30 -30
  87. flwr/proto/fleet_pb2_grpc.pyi +20 -20
  88. flwr/proto/grpcadapter_pb2.py +1 -1
  89. flwr/proto/log_pb2.py +1 -1
  90. flwr/proto/message_pb2.py +1 -1
  91. flwr/proto/node_pb2.py +3 -3
  92. flwr/proto/node_pb2.pyi +1 -4
  93. flwr/proto/recordset_pb2.py +1 -1
  94. flwr/proto/run_pb2.py +1 -1
  95. flwr/proto/serverappio_pb2.py +24 -25
  96. flwr/proto/serverappio_pb2.pyi +32 -32
  97. flwr/proto/serverappio_pb2_grpc.py +62 -28
  98. flwr/proto/serverappio_pb2_grpc.pyi +29 -16
  99. flwr/proto/simulationio_pb2.py +3 -3
  100. flwr/proto/simulationio_pb2_grpc.py +34 -0
  101. flwr/proto/simulationio_pb2_grpc.pyi +13 -0
  102. flwr/proto/task_pb2.py +1 -1
  103. flwr/proto/transport_pb2.py +1 -1
  104. flwr/server/app.py +152 -112
  105. flwr/server/compat/app_utils.py +7 -2
  106. flwr/server/compat/driver_client_proxy.py +1 -2
  107. flwr/server/driver/grpc_driver.py +38 -85
  108. flwr/server/driver/inmemory_driver.py +7 -2
  109. flwr/server/run_serverapp.py +8 -9
  110. flwr/server/serverapp/app.py +37 -13
  111. flwr/server/strategy/dpfedavg_fixed.py +1 -0
  112. flwr/server/superlink/driver/serverappio_grpc.py +2 -1
  113. flwr/server/superlink/driver/serverappio_servicer.py +148 -63
  114. flwr/server/superlink/ffs/disk_ffs.py +1 -0
  115. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +20 -87
  116. flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -0
  117. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +2 -165
  118. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +56 -35
  119. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +99 -169
  120. flwr/server/superlink/fleet/message_handler/message_handler.py +69 -29
  121. flwr/server/superlink/fleet/rest_rere/rest_api.py +20 -19
  122. flwr/server/superlink/fleet/vce/__init__.py +1 -0
  123. flwr/server/superlink/fleet/vce/backend/__init__.py +1 -0
  124. flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -0
  125. flwr/server/superlink/fleet/vce/vce_api.py +2 -2
  126. flwr/server/superlink/linkstate/in_memory_linkstate.py +60 -99
  127. flwr/server/superlink/linkstate/linkstate.py +30 -36
  128. flwr/server/superlink/linkstate/sqlite_linkstate.py +105 -188
  129. flwr/server/superlink/linkstate/utils.py +18 -8
  130. flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
  131. flwr/server/superlink/simulation/simulationio_servicer.py +33 -0
  132. flwr/server/superlink/utils.py +65 -0
  133. flwr/server/utils/validator.py +9 -34
  134. flwr/simulation/app.py +20 -10
  135. flwr/simulation/legacy_app.py +4 -2
  136. flwr/simulation/ray_transport/ray_actor.py +1 -0
  137. flwr/simulation/ray_transport/utils.py +1 -0
  138. flwr/simulation/run_simulation.py +36 -22
  139. flwr/simulation/simulationio_connection.py +5 -1
  140. flwr/superexec/app.py +1 -0
  141. flwr/superexec/deployment.py +1 -0
  142. flwr/superexec/exec_grpc.py +20 -2
  143. flwr/superexec/exec_servicer.py +97 -2
  144. flwr/superexec/exec_user_auth_interceptor.py +101 -0
  145. flwr/superexec/executor.py +1 -0
  146. {flwr-1.13.1.dist-info → flwr-1.15.0.dist-info}/METADATA +14 -13
  147. {flwr-1.13.1.dist-info → flwr-1.15.0.dist-info}/RECORD +150 -144
  148. flwr/proto/common_pb2.py +0 -36
  149. flwr/proto/common_pb2.pyi +0 -121
  150. flwr/proto/common_pb2_grpc.py +0 -4
  151. flwr/proto/common_pb2_grpc.pyi +0 -4
  152. flwr/proto/control_pb2.py +0 -27
  153. flwr/proto/control_pb2.pyi +0 -7
  154. flwr/proto/control_pb2_grpc.py +0 -135
  155. flwr/proto/control_pb2_grpc.pyi +0 -53
  156. {flwr-1.13.1.dist-info → flwr-1.15.0.dist-info}/LICENSE +0 -0
  157. {flwr-1.13.1.dist-info → flwr-1.15.0.dist-info}/WHEEL +0 -0
  158. {flwr-1.13.1.dist-info → flwr-1.15.0.dist-info}/entry_points.txt +0 -0
flwr/client/app.py CHANGED
@@ -14,13 +14,15 @@
14
14
  # ==============================================================================
15
15
  """Flower client app."""
16
16
 
17
- import signal
18
- import subprocess
17
+
18
+ import multiprocessing
19
+ import os
19
20
  import sys
21
+ import threading
20
22
  import time
21
23
  from contextlib import AbstractContextManager
22
- from dataclasses import dataclass
23
24
  from logging import ERROR, INFO, WARN
25
+ from os import urandom
24
26
  from pathlib import Path
25
27
  from typing import Callable, Optional, Union, cast
26
28
 
@@ -32,6 +34,7 @@ from flwr.cli.config_utils import get_fab_metadata
32
34
  from flwr.cli.install import install_from_fab
33
35
  from flwr.client.client import Client
34
36
  from flwr.client.client_app import ClientApp, LoadClientAppError
37
+ from flwr.client.clientapp.app import flwr_clientapp
35
38
  from flwr.client.nodestate.nodestate_factory import NodeStateFactory
36
39
  from flwr.client.typing import ClientFnExt
37
40
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH, Context, EventType, Message, event
@@ -42,7 +45,6 @@ from flwr.common.constant import (
42
45
  ISOLATION_MODE_PROCESS,
43
46
  ISOLATION_MODE_SUBPROCESS,
44
47
  MAX_RETRY_DELAY,
45
- MISSING_EXTRA_REST,
46
48
  RUN_ID_NUM_BYTES,
47
49
  SERVER_OCTET,
48
50
  TRANSPORT_TYPE_GRPC_ADAPTER,
@@ -52,13 +54,13 @@ from flwr.common.constant import (
52
54
  TRANSPORT_TYPES,
53
55
  ErrorCode,
54
56
  )
57
+ from flwr.common.exit import ExitCode, flwr_exit
58
+ from flwr.common.grpc import generic_create_grpc_server
55
59
  from flwr.common.logger import log, warn_deprecated_feature
56
60
  from flwr.common.message import Error
57
61
  from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential
58
- from flwr.common.typing import Fab, Run, UserConfig
62
+ from flwr.common.typing import Fab, Run, RunNotRunningException, UserConfig
59
63
  from flwr.proto.clientappio_pb2_grpc import add_ClientAppIoServicer_to_server
60
- from flwr.server.superlink.fleet.grpc_bidi.grpc_server import generic_create_grpc_server
61
- from flwr.server.superlink.linkstate.utils import generate_rand_int_from_bytes
62
64
 
63
65
  from .clientapp.clientappio_servicer import ClientAppInputs, ClientAppIoServicer
64
66
  from .grpc_adapter_client.connection import grpc_adapter
@@ -344,10 +346,7 @@ def start_client_internal(
344
346
  transport, server_address
345
347
  )
346
348
 
347
- app_state_tracker = _AppStateTracker()
348
-
349
349
  def _on_sucess(retry_state: RetryState) -> None:
350
- app_state_tracker.is_connected = True
351
350
  if retry_state.tries > 1:
352
351
  log(
353
352
  INFO,
@@ -357,7 +356,6 @@ def start_client_internal(
357
356
  )
358
357
 
359
358
  def _on_backoff(retry_state: RetryState) -> None:
360
- app_state_tracker.is_connected = False
361
359
  if retry_state.tries == 1:
362
360
  log(WARN, "Connection attempt failed, retrying...")
363
361
  else:
@@ -390,10 +388,11 @@ def start_client_internal(
390
388
  run_info_store: Optional[DeprecatedRunInfoStore] = None
391
389
  state_factory = NodeStateFactory()
392
390
  state = state_factory.state()
391
+ mp_spawn_context = multiprocessing.get_context("spawn")
393
392
 
394
393
  runs: dict[int, Run] = {}
395
394
 
396
- while not app_state_tracker.interrupt:
395
+ while True:
397
396
  sleep_duration: int = 0
398
397
  with connection(
399
398
  address,
@@ -432,9 +431,8 @@ def start_client_internal(
432
431
  node_config=node_config,
433
432
  )
434
433
 
435
- app_state_tracker.register_signal_handler()
436
434
  # pylint: disable=too-many-nested-blocks
437
- while not app_state_tracker.interrupt:
435
+ while True:
438
436
  try:
439
437
  # Receive
440
438
  message = receive()
@@ -474,7 +472,7 @@ def start_client_internal(
474
472
 
475
473
  run: Run = runs[run_id]
476
474
  if get_fab is not None and run.fab_hash:
477
- fab = get_fab(run.fab_hash)
475
+ fab = get_fab(run.fab_hash, run_id)
478
476
  if not isolation:
479
477
  # If `ClientApp` runs in the same process, install the FAB
480
478
  install_from_fab(fab.content, flwr_path, True)
@@ -512,7 +510,7 @@ def start_client_internal(
512
510
  # Docker container.
513
511
 
514
512
  # Generate SuperNode token
515
- token: int = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
513
+ token = int.from_bytes(urandom(RUN_ID_NUM_BYTES), "little")
516
514
 
517
515
  # Mode 1: SuperNode starts ClientApp as subprocess
518
516
  start_subprocess = isolation == ISOLATION_MODE_SUBPROCESS
@@ -548,12 +546,13 @@ def start_client_internal(
548
546
  ]
549
547
  command.append("--insecure")
550
548
 
551
- subprocess.run(
552
- command,
553
- stdout=None,
554
- stderr=None,
555
- check=True,
549
+ proc = mp_spawn_context.Process(
550
+ target=_run_flwr_clientapp,
551
+ args=(command, os.getpid()),
552
+ daemon=True,
556
553
  )
554
+ proc.start()
555
+ proc.join()
557
556
  else:
558
557
  # Wait for output to become available
559
558
  while not clientappio_servicer.has_outputs():
@@ -591,10 +590,7 @@ def start_client_internal(
591
590
  e_code = ErrorCode.LOAD_CLIENT_APP_EXCEPTION
592
591
  exc_entity = "SuperNode"
593
592
 
594
- if not app_state_tracker.interrupt:
595
- log(
596
- ERROR, "%s raised an exception", exc_entity, exc_info=ex
597
- )
593
+ log(ERROR, "%s raised an exception", exc_entity, exc_info=ex)
598
594
 
599
595
  # Create error message
600
596
  reply_message = message.create_error_reply(
@@ -611,18 +607,23 @@ def start_client_internal(
611
607
  send(reply_message)
612
608
  log(INFO, "Sent reply")
613
609
 
614
- except StopIteration:
615
- sleep_duration = 0
616
- break
610
+ except RunNotRunningException:
611
+ log(INFO, "")
612
+ log(
613
+ INFO,
614
+ "SuperNode aborted sending the reply message. "
615
+ "Run ID %s is not in `RUNNING` status.",
616
+ run_id,
617
+ )
618
+ log(INFO, "")
617
619
  # pylint: enable=too-many-nested-blocks
618
620
 
619
621
  # Unregister node
620
- if delete_node is not None and app_state_tracker.is_connected:
622
+ if delete_node is not None:
621
623
  delete_node() # pylint: disable=not-callable
622
624
 
623
625
  if sleep_duration == 0:
624
626
  log(INFO, "Disconnect and shut down")
625
- del app_state_tracker
626
627
  break
627
628
 
628
629
  # Sleep and reconnect afterwards
@@ -752,7 +753,7 @@ def _init_connection(transport: Optional[str], server_address: str) -> tuple[
752
753
  Optional[Callable[[], Optional[int]]],
753
754
  Optional[Callable[[], None]],
754
755
  Optional[Callable[[int], Run]],
755
- Optional[Callable[[str], Fab]],
756
+ Optional[Callable[[str, int], Fab]],
756
757
  ]
757
758
  ],
758
759
  ],
@@ -762,7 +763,10 @@ def _init_connection(transport: Optional[str], server_address: str) -> tuple[
762
763
  # Parse IP address
763
764
  parsed_address = parse_address(server_address)
764
765
  if not parsed_address:
765
- sys.exit(f"Server address ({server_address}) cannot be parsed.")
766
+ flwr_exit(
767
+ ExitCode.COMMON_ADDRESS_INVALID,
768
+ f"SuperLink address ({server_address}) cannot be parsed.",
769
+ )
766
770
  host, port, is_v6 = parsed_address
767
771
  address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}"
768
772
 
@@ -777,12 +781,9 @@ def _init_connection(transport: Optional[str], server_address: str) -> tuple[
777
781
 
778
782
  from .rest_client.connection import http_request_response
779
783
  except ModuleNotFoundError:
780
- sys.exit(MISSING_EXTRA_REST)
784
+ flwr_exit(ExitCode.COMMON_MISSING_EXTRA_REST)
781
785
  if server_address[:4] != "http":
782
- sys.exit(
783
- "When using the REST API, please provide `https://` or "
784
- "`http://` before the server address (e.g. `http://127.0.0.1:8080`)"
785
- )
786
+ flwr_exit(ExitCode.SUPERNODE_REST_ADDRESS_INVALID)
786
787
  connection, error_type = http_request_response, RequestsConnectionError
787
788
  elif transport == TRANSPORT_TYPE_GRPC_RERE:
788
789
  connection, error_type = grpc_request_response, RpcError
@@ -798,21 +799,19 @@ def _init_connection(transport: Optional[str], server_address: str) -> tuple[
798
799
  return connection, address, error_type
799
800
 
800
801
 
801
- @dataclass
802
- class _AppStateTracker:
803
- interrupt: bool = False
804
- is_connected: bool = False
805
-
806
- def register_signal_handler(self) -> None:
807
- """Register handlers for exit signals."""
802
+ def _run_flwr_clientapp(args: list[str], main_pid: int) -> None:
803
+ # Monitor the main process in case of SIGKILL
804
+ def main_process_monitor() -> None:
805
+ while True:
806
+ time.sleep(1)
807
+ if os.getppid() != main_pid:
808
+ os.kill(os.getpid(), 9)
808
809
 
809
- def signal_handler(sig, frame): # type: ignore
810
- # pylint: disable=unused-argument
811
- self.interrupt = True
812
- raise StopIteration from None
810
+ threading.Thread(target=main_process_monitor, daemon=True).start()
813
811
 
814
- signal.signal(signal.SIGINT, signal_handler)
815
- signal.signal(signal.SIGTERM, signal_handler)
812
+ # Run the command
813
+ sys.argv = args
814
+ flwr_clientapp()
816
815
 
817
816
 
818
817
  def run_clientappio_api_grpc(
flwr/client/client.py CHANGED
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Flower client (abstract base class)."""
16
16
 
17
+
17
18
  # Needed to `Client` class can return a type of `Client` (not needed in py3.11+)
18
19
  from __future__ import annotations
19
20
 
@@ -21,7 +22,6 @@ from abc import ABC
21
22
 
22
23
  from flwr.common import (
23
24
  Code,
24
- Context,
25
25
  EvaluateIns,
26
26
  EvaluateRes,
27
27
  FitIns,
@@ -33,14 +33,11 @@ from flwr.common import (
33
33
  Parameters,
34
34
  Status,
35
35
  )
36
- from flwr.common.logger import warn_deprecated_feature_with_example
37
36
 
38
37
 
39
38
  class Client(ABC):
40
39
  """Abstract base class for Flower clients."""
41
40
 
42
- _context: Context
43
-
44
41
  def get_properties(self, ins: GetPropertiesIns) -> GetPropertiesRes:
45
42
  """Return set of client's properties.
46
43
 
@@ -142,34 +139,6 @@ class Client(ABC):
142
139
  metrics={},
143
140
  )
144
141
 
145
- @property
146
- def context(self) -> Context:
147
- """Getter for `Context` client attribute."""
148
- warn_deprecated_feature_with_example(
149
- "Accessing the context via the client's attribute is deprecated.",
150
- example_message="Instead, pass it to the client's "
151
- "constructor in your `client_fn()` which already "
152
- "receives a context object.",
153
- code_example="def client_fn(context: Context) -> Client:\n\n"
154
- "\t\t# Your existing client_fn\n\n"
155
- "\t\t# Pass `context` to the constructor\n"
156
- "\t\treturn FlowerClient(context).to_client()",
157
- )
158
- return self._context
159
-
160
- @context.setter
161
- def context(self, context: Context) -> None:
162
- """Setter for `Context` client attribute."""
163
- self._context = context
164
-
165
- def get_context(self) -> Context:
166
- """Get the run context from this client."""
167
- return self.context
168
-
169
- def set_context(self, context: Context) -> None:
170
- """Apply a run context to this client."""
171
- self.context = context
172
-
173
142
  def to_client(self) -> Client:
174
143
  """Return client (itself)."""
175
144
  return self
flwr/client/clientapp/app.py CHANGED
@@ -14,8 +14,8 @@
14
14
  # ==============================================================================
15
15
  """Flower ClientApp process."""
16
16
 
17
+
17
18
  import argparse
18
- import sys
19
19
  import time
20
20
  from logging import DEBUG, ERROR, INFO
21
21
  from typing import Optional
@@ -28,9 +28,11 @@ from flwr.common import Context, Message
28
28
  from flwr.common.args import add_args_flwr_app_common
29
29
  from flwr.common.config import get_flwr_dir
30
30
  from flwr.common.constant import CLIENTAPPIO_API_DEFAULT_CLIENT_ADDRESS, ErrorCode
31
- from flwr.common.grpc import create_channel
31
+ from flwr.common.exit import ExitCode, flwr_exit
32
+ from flwr.common.grpc import create_channel, on_channel_state_change
32
33
  from flwr.common.logger import log
33
34
  from flwr.common.message import Error
35
+ from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub
34
36
  from flwr.common.serde import (
35
37
  context_from_proto,
36
38
  context_to_proto,
@@ -59,18 +61,16 @@ def flwr_clientapp() -> None:
59
61
  """Run process-isolated Flower ClientApp."""
60
62
  args = _parse_args_run_flwr_clientapp().parse_args()
61
63
  if not args.insecure:
62
- log(
63
- ERROR,
64
- "flwr-clientapp does not support TLS yet. "
65
- "Please use the '--insecure' flag.",
64
+ flwr_exit(
65
+ ExitCode.COMMON_TLS_NOT_SUPPORTED,
66
+ "flwr-clientapp does not support TLS yet.",
66
67
  )
67
- sys.exit(1)
68
68
 
69
- log(INFO, "Starting Flower ClientApp")
69
+ log(INFO, "Start `flwr-clientapp` process")
70
70
  log(
71
71
  DEBUG,
72
- "Starting isolated `ClientApp` connected to SuperNode's ClientAppIo API at %s "
73
- "with token %s",
72
+ "`flwr-clientapp` will attempt to connect to SuperNode's "
73
+ "ClientAppIo API at %s with token %s",
74
74
  args.clientappio_api_address,
75
75
  args.token,
76
76
  )
@@ -83,11 +83,6 @@ def flwr_clientapp() -> None:
83
83
  )
84
84
 
85
85
 
86
- def on_channel_state_change(channel_connectivity: str) -> None:
87
- """Log channel connectivity."""
88
- log(DEBUG, channel_connectivity)
89
-
90
-
91
86
  def run_clientapp( # pylint: disable=R0914
92
87
  clientappio_api_address: str,
93
88
  run_once: bool,
@@ -105,9 +100,9 @@ def run_clientapp( # pylint: disable=R0914
105
100
 
106
101
  # Resolve directory where FABs are installed
107
102
  flwr_dir_ = get_flwr_dir(flwr_dir)
108
-
109
103
  try:
110
104
  stub = ClientAppIoStub(channel)
105
+ _wrap_stub(stub, _make_simple_grpc_retry_invoker())
111
106
 
112
107
  while True:
113
108
  # If token is not set, loop until token is received from SuperNode
@@ -116,11 +111,11 @@ def run_clientapp( # pylint: disable=R0914
116
111
  time.sleep(1)
117
112
 
118
113
  # Pull Message, Context, Run and (optional) FAB from SuperNode
119
- message, context, run, fab = pull_message(stub=stub, token=token)
114
+ message, context, run, fab = pull_clientappinputs(stub=stub, token=token)
120
115
 
121
116
  # Install FAB, if provided
122
117
  if fab:
123
- log(DEBUG, "Flower ClientApp starts FAB installation.")
118
+ log(DEBUG, "[flwr-clientapp] Start FAB installation.")
124
119
  install_from_fab(fab.content, flwr_dir=flwr_dir_, skip_prompt=True)
125
120
 
126
121
  load_client_app_fn = get_load_client_app_fn(
@@ -132,12 +127,14 @@ def run_clientapp( # pylint: disable=R0914
132
127
 
133
128
  try:
134
129
  # Load ClientApp
130
+ log(DEBUG, "[flwr-clientapp] Start `ClientApp` Loading.")
135
131
  client_app: ClientApp = load_client_app_fn(
136
132
  run.fab_id, run.fab_version, fab.hash_str if fab else ""
137
133
  )
138
134
 
139
135
  # Execute ClientApp
140
136
  reply_message = client_app(message=message, context=context)
137
+
141
138
  except Exception as ex: # pylint: disable=broad-exception-caught
142
139
  # Don't update/change NodeState
143
140
 
@@ -159,7 +156,7 @@ def run_clientapp( # pylint: disable=R0914
159
156
  )
160
157
 
161
158
  # Push Message and Context to SuperNode
162
- _ = push_message(
159
+ _ = push_clientappoutputs(
163
160
  stub=stub, token=token, message=reply_message, context=context
164
161
  )
165
162
 
@@ -182,7 +179,7 @@ def run_clientapp( # pylint: disable=R0914
182
179
 
183
180
  def get_token(stub: grpc.Channel) -> Optional[int]:
184
181
  """Get a token from SuperNode."""
185
- log(DEBUG, "Flower ClientApp process requests token")
182
+ log(DEBUG, "[flwr-clientapp] Request token")
186
183
  try:
187
184
  res: GetTokenResponse = stub.GetToken(GetTokenRequest())
188
185
  log(DEBUG, "[GetToken] Received token: %s", res.token)
@@ -195,11 +192,11 @@ def get_token(stub: grpc.Channel) -> Optional[int]:
195
192
  return None
196
193
 
197
194
 
198
- def pull_message(
195
+ def pull_clientappinputs(
199
196
  stub: grpc.Channel, token: int
200
197
  ) -> tuple[Message, Context, Run, Optional[Fab]]:
201
- """Pull message from SuperNode to ClientApp."""
202
- log(INFO, "Pulling ClientAppInputs for token %s", token)
198
+ """Pull ClientAppInputs from SuperNode."""
199
+ log(INFO, "[flwr-clientapp] Pull `ClientAppInputs` for token %s", token)
203
200
  try:
204
201
  res: PullClientAppInputsResponse = stub.PullClientAppInputs(
205
202
  PullClientAppInputsRequest(token=token)
@@ -214,11 +211,11 @@ def pull_message(
214
211
  raise e
215
212
 
216
213
 
217
- def push_message(
214
+ def push_clientappoutputs(
218
215
  stub: grpc.Channel, token: int, message: Message, context: Context
219
216
  ) -> PushClientAppOutputsResponse:
220
- """Push message to SuperNode from ClientApp."""
221
- log(INFO, "Pushing ClientAppOutputs for token %s", token)
217
+ """Push ClientAppOutputs to SuperNode."""
218
+ log(INFO, "[flwr-clientapp] Push `ClientAppOutputs` for token %s", token)
222
219
  proto_message = message_to_proto(message)
223
220
  proto_context = context_to_proto(context)
224
221
 
flwr/client/clientapp/utils.py CHANGED
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Flower ClientApp loading utils."""
16
16
 
17
+
17
18
  from logging import DEBUG
18
19
  from pathlib import Path
19
20
  from typing import Callable, Optional
@@ -65,7 +66,7 @@ def get_load_client_app_fn(
65
66
  # `fab_hash` is not required since the app is loaded from `runtime_app_dir`.
66
67
  elif app_path is not None:
67
68
  config = get_project_config(runtime_app_dir)
68
- this_fab_version, this_fab_id = get_metadata_from_config(config)
69
+ this_fab_id, this_fab_version = get_metadata_from_config(config)
69
70
 
70
71
  if this_fab_version != fab_version or this_fab_id != fab_id:
71
72
  raise LoadClientAppError(
flwr/client/grpc_adapter_client/connection.py CHANGED
@@ -48,7 +48,7 @@ def grpc_adapter( # pylint: disable=R0913,too-many-positional-arguments
48
48
  Optional[Callable[[], Optional[int]]],
49
49
  Optional[Callable[[], None]],
50
50
  Optional[Callable[[int], Run]],
51
- Optional[Callable[[str], Fab]],
51
+ Optional[Callable[[str, int], Fab]],
52
52
  ]
53
53
  ]:
54
54
  """Primitives for request/response-based interaction with a server via GrpcAdapter.
flwr/client/grpc_client/connection.py CHANGED
@@ -36,7 +36,7 @@ from flwr.common import (
36
36
  from flwr.common import recordset_compat as compat
37
37
  from flwr.common import serde
38
38
  from flwr.common.constant import MessageType, MessageTypeLegacy
39
- from flwr.common.grpc import create_channel
39
+ from flwr.common.grpc import create_channel, on_channel_state_change
40
40
  from flwr.common.logger import log
41
41
  from flwr.common.retry_invoker import RetryInvoker
42
42
  from flwr.common.typing import Fab, Run
@@ -47,17 +47,6 @@ from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
47
47
  )
48
48
  from flwr.proto.transport_pb2_grpc import FlowerServiceStub # pylint: disable=E0611
49
49
 
50
- # The following flags can be uncommented for debugging. Other possible values:
51
- # https://github.com/grpc/grpc/blob/master/doc/environment_variables.md
52
- # import os
53
- # os.environ["GRPC_VERBOSITY"] = "debug"
54
- # os.environ["GRPC_TRACE"] = "tcp,http"
55
-
56
-
57
- def on_channel_state_change(channel_connectivity: str) -> None:
58
- """Log channel connectivity."""
59
- log(DEBUG, channel_connectivity)
60
-
61
50
 
62
51
  @contextmanager
63
52
  def grpc_connection( # pylint: disable=R0913,R0915,too-many-positional-arguments
@@ -76,7 +65,7 @@ def grpc_connection( # pylint: disable=R0913,R0915,too-many-positional-argument
76
65
  Optional[Callable[[], Optional[int]]],
77
66
  Optional[Callable[[], None]],
78
67
  Optional[Callable[[int], Run]],
79
- Optional[Callable[[str], Fab]],
68
+ Optional[Callable[[str, int], Fab]],
80
69
  ]
81
70
  ]:
82
71
  """Establish a gRPC connection to a gRPC server.
flwr/client/grpc_rere_client/client_interceptor.py CHANGED
@@ -15,67 +15,18 @@
15
15
  """Flower client interceptor."""
16
16
 
17
17
 
18
- import base64
19
- import collections
20
- from collections.abc import Sequence
21
- from logging import WARNING
22
- from typing import Any, Callable, Optional, Union
18
+ from typing import Any, Callable
23
19
 
24
20
  import grpc
25
21
  from cryptography.hazmat.primitives.asymmetric import ec
22
+ from google.protobuf.message import Message as GrpcMessage
26
23
 
27
- from flwr.common.logger import log
24
+ from flwr.common import now
25
+ from flwr.common.constant import PUBLIC_KEY_HEADER, SIGNATURE_HEADER, TIMESTAMP_HEADER
28
26
  from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
29
- bytes_to_public_key,
30
- compute_hmac,
31
- generate_shared_key,
32
27
  public_key_to_bytes,
28
+ sign_message,
33
29
  )
34
- from flwr.proto.fab_pb2 import GetFabRequest # pylint: disable=E0611
35
- from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
36
- CreateNodeRequest,
37
- DeleteNodeRequest,
38
- PingRequest,
39
- PullTaskInsRequest,
40
- PushTaskResRequest,
41
- )
42
- from flwr.proto.run_pb2 import GetRunRequest # pylint: disable=E0611
43
-
44
- _PUBLIC_KEY_HEADER = "public-key"
45
- _AUTH_TOKEN_HEADER = "auth-token"
46
-
47
- Request = Union[
48
- CreateNodeRequest,
49
- DeleteNodeRequest,
50
- PullTaskInsRequest,
51
- PushTaskResRequest,
52
- GetRunRequest,
53
- PingRequest,
54
- GetFabRequest,
55
- ]
56
-
57
-
58
- def _get_value_from_tuples(
59
- key_string: str, tuples: Sequence[tuple[str, Union[str, bytes]]]
60
- ) -> bytes:
61
- value = next((value for key, value in tuples if key == key_string), "")
62
- if isinstance(value, str):
63
- return value.encode()
64
-
65
- return value
66
-
67
-
68
- class _ClientCallDetails(
69
- collections.namedtuple(
70
- "_ClientCallDetails", ("method", "timeout", "metadata", "credentials")
71
- ),
72
- grpc.ClientCallDetails, # type: ignore
73
- ):
74
- """Details for each client call.
75
-
76
- The class will be passed on as the first argument in continuation function.
77
- In our case, `AuthenticateClientInterceptor` adds new metadata to the construct.
78
- """
79
30
 
80
31
 
81
32
  class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type: ignore
@@ -87,84 +38,33 @@ class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type:
87
38
  public_key: ec.EllipticCurvePublicKey,
88
39
  ):
89
40
  self.private_key = private_key
90
- self.public_key = public_key
91
- self.shared_secret: Optional[bytes] = None
92
- self.server_public_key: Optional[ec.EllipticCurvePublicKey] = None
93
- self.encoded_public_key = base64.urlsafe_b64encode(
94
- public_key_to_bytes(self.public_key)
95
- )
41
+ self.public_key_bytes = public_key_to_bytes(public_key)
96
42
 
97
43
  def intercept_unary_unary(
98
44
  self,
99
45
  continuation: Callable[[Any, Any], Any],
100
46
  client_call_details: grpc.ClientCallDetails,
101
- request: Request,
47
+ request: GrpcMessage,
102
48
  ) -> grpc.Call:
103
49
  """Flower client interceptor.
104
50
 
105
51
  Intercept unary call from client and add necessary authentication header in the
106
52
  RPC metadata.
107
53
  """
108
- metadata = []
109
- postprocess = False
110
- if client_call_details.metadata is not None:
111
- metadata = list(client_call_details.metadata)
112
-
113
- # Always add the public key header
114
- metadata.append(
115
- (
116
- _PUBLIC_KEY_HEADER,
117
- self.encoded_public_key,
118
- )
119
- )
120
-
121
- if isinstance(request, CreateNodeRequest):
122
- postprocess = True
123
- elif isinstance(
124
- request,
125
- (
126
- DeleteNodeRequest,
127
- PullTaskInsRequest,
128
- PushTaskResRequest,
129
- GetRunRequest,
130
- PingRequest,
131
- GetFabRequest,
132
- ),
133
- ):
134
- if self.shared_secret is None:
135
- raise RuntimeError("Failure to compute hmac")
136
-
137
- message_bytes = request.SerializeToString(deterministic=True)
138
- metadata.append(
139
- (
140
- _AUTH_TOKEN_HEADER,
141
- base64.urlsafe_b64encode(
142
- compute_hmac(self.shared_secret, message_bytes)
143
- ),
144
- )
145
- )
54
+ metadata = list(client_call_details.metadata or [])
146
55
 
147
- client_call_details = _ClientCallDetails(
148
- client_call_details.method,
149
- client_call_details.timeout,
150
- metadata,
151
- client_call_details.credentials,
152
- )
56
+ # Add the public key
57
+ metadata.append((PUBLIC_KEY_HEADER, self.public_key_bytes))
153
58
 
154
- response = continuation(client_call_details, request)
155
- if postprocess:
156
- server_public_key_bytes = base64.urlsafe_b64decode(
157
- _get_value_from_tuples(_PUBLIC_KEY_HEADER, response.initial_metadata())
158
- )
59
+ # Add timestamp
60
+ timestamp = now().isoformat()
61
+ metadata.append((TIMESTAMP_HEADER, timestamp))
159
62
 
160
- if server_public_key_bytes != b"":
161
- self.server_public_key = bytes_to_public_key(server_public_key_bytes)
162
- else:
163
- log(WARNING, "Can't get server public key, SuperLink may be offline")
63
+ # Sign and add the signature
64
+ signature = sign_message(self.private_key, timestamp.encode("ascii"))
65
+ metadata.append((SIGNATURE_HEADER, signature))
164
66
 
165
- if self.server_public_key is not None:
166
- self.shared_secret = generate_shared_key(
167
- self.private_key, self.server_public_key
168
- )
67
+ # Overwrite the metadata
68
+ details = client_call_details._replace(metadata=metadata)
169
69
 
170
- return response
70
+ return continuation(details, request)