flwr-nightly 1.15.0.dev20250104__py3-none-any.whl → 1.15.0.dev20250123__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98):
  1. flwr/cli/cli_user_auth_interceptor.py +6 -2
  2. flwr/cli/config_utils.py +23 -146
  3. flwr/cli/constant.py +27 -0
  4. flwr/cli/install.py +1 -1
  5. flwr/cli/log.py +17 -2
  6. flwr/cli/login/login.py +20 -5
  7. flwr/cli/ls.py +10 -2
  8. flwr/cli/run/run.py +20 -10
  9. flwr/cli/stop.py +9 -1
  10. flwr/cli/utils.py +4 -4
  11. flwr/client/app.py +36 -48
  12. flwr/client/clientapp/app.py +4 -6
  13. flwr/client/clientapp/utils.py +1 -1
  14. flwr/client/grpc_client/connection.py +0 -6
  15. flwr/client/grpc_rere_client/client_interceptor.py +19 -119
  16. flwr/client/grpc_rere_client/connection.py +34 -24
  17. flwr/client/grpc_rere_client/grpc_adapter.py +16 -0
  18. flwr/client/rest_client/connection.py +34 -26
  19. flwr/client/supernode/app.py +14 -20
  20. flwr/common/auth_plugin/auth_plugin.py +34 -23
  21. flwr/common/config.py +152 -15
  22. flwr/common/constant.py +11 -8
  23. flwr/common/exit/__init__.py +24 -0
  24. flwr/common/exit/exit.py +99 -0
  25. flwr/common/exit/exit_code.py +93 -0
  26. flwr/common/exit_handlers.py +24 -10
  27. flwr/common/grpc.py +161 -3
  28. flwr/common/logger.py +1 -1
  29. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +45 -0
  30. flwr/common/serde.py +6 -4
  31. flwr/common/typing.py +20 -0
  32. flwr/proto/clientappio_pb2.py +13 -3
  33. flwr/proto/clientappio_pb2_grpc.py +63 -12
  34. flwr/proto/error_pb2.py +13 -3
  35. flwr/proto/error_pb2_grpc.py +20 -0
  36. flwr/proto/exec_pb2.py +27 -29
  37. flwr/proto/exec_pb2.pyi +27 -54
  38. flwr/proto/exec_pb2_grpc.py +105 -24
  39. flwr/proto/fab_pb2.py +13 -3
  40. flwr/proto/fab_pb2_grpc.py +20 -0
  41. flwr/proto/fleet_pb2.py +54 -31
  42. flwr/proto/fleet_pb2.pyi +84 -0
  43. flwr/proto/fleet_pb2_grpc.py +207 -28
  44. flwr/proto/fleet_pb2_grpc.pyi +26 -0
  45. flwr/proto/grpcadapter_pb2.py +14 -4
  46. flwr/proto/grpcadapter_pb2_grpc.py +35 -4
  47. flwr/proto/log_pb2.py +13 -3
  48. flwr/proto/log_pb2_grpc.py +20 -0
  49. flwr/proto/message_pb2.py +15 -5
  50. flwr/proto/message_pb2_grpc.py +20 -0
  51. flwr/proto/node_pb2.py +15 -5
  52. flwr/proto/node_pb2.pyi +1 -4
  53. flwr/proto/node_pb2_grpc.py +20 -0
  54. flwr/proto/recordset_pb2.py +18 -8
  55. flwr/proto/recordset_pb2_grpc.py +20 -0
  56. flwr/proto/run_pb2.py +16 -6
  57. flwr/proto/run_pb2_grpc.py +20 -0
  58. flwr/proto/serverappio_pb2.py +32 -14
  59. flwr/proto/serverappio_pb2.pyi +56 -0
  60. flwr/proto/serverappio_pb2_grpc.py +261 -44
  61. flwr/proto/serverappio_pb2_grpc.pyi +20 -0
  62. flwr/proto/simulationio_pb2.py +13 -3
  63. flwr/proto/simulationio_pb2_grpc.py +105 -24
  64. flwr/proto/task_pb2.py +13 -3
  65. flwr/proto/task_pb2_grpc.py +20 -0
  66. flwr/proto/transport_pb2.py +20 -10
  67. flwr/proto/transport_pb2_grpc.py +35 -4
  68. flwr/server/app.py +87 -38
  69. flwr/server/compat/app_utils.py +0 -1
  70. flwr/server/compat/driver_client_proxy.py +1 -2
  71. flwr/server/driver/grpc_driver.py +5 -2
  72. flwr/server/driver/inmemory_driver.py +2 -1
  73. flwr/server/serverapp/app.py +5 -6
  74. flwr/server/superlink/driver/serverappio_grpc.py +1 -1
  75. flwr/server/superlink/driver/serverappio_servicer.py +132 -14
  76. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +20 -88
  77. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +2 -165
  78. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +38 -0
  79. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +95 -168
  80. flwr/server/superlink/fleet/message_handler/message_handler.py +66 -5
  81. flwr/server/superlink/fleet/rest_rere/rest_api.py +28 -3
  82. flwr/server/superlink/fleet/vce/vce_api.py +2 -2
  83. flwr/server/superlink/linkstate/in_memory_linkstate.py +40 -48
  84. flwr/server/superlink/linkstate/linkstate.py +15 -22
  85. flwr/server/superlink/linkstate/sqlite_linkstate.py +80 -99
  86. flwr/server/superlink/linkstate/utils.py +18 -8
  87. flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
  88. flwr/server/utils/validator.py +9 -34
  89. flwr/simulation/app.py +4 -6
  90. flwr/simulation/legacy_app.py +4 -2
  91. flwr/simulation/run_simulation.py +1 -1
  92. flwr/superexec/exec_grpc.py +1 -1
  93. flwr/superexec/exec_servicer.py +23 -2
  94. {flwr_nightly-1.15.0.dev20250104.dist-info → flwr_nightly-1.15.0.dev20250123.dist-info}/METADATA +7 -7
  95. {flwr_nightly-1.15.0.dev20250104.dist-info → flwr_nightly-1.15.0.dev20250123.dist-info}/RECORD +98 -94
  96. {flwr_nightly-1.15.0.dev20250104.dist-info → flwr_nightly-1.15.0.dev20250123.dist-info}/LICENSE +0 -0
  97. {flwr_nightly-1.15.0.dev20250104.dist-info → flwr_nightly-1.15.0.dev20250123.dist-info}/WHEEL +0 -0
  98. {flwr_nightly-1.15.0.dev20250104.dist-info → flwr_nightly-1.15.0.dev20250123.dist-info}/entry_points.txt +0 -0
flwr/client/app.py CHANGED
@@ -15,13 +15,14 @@
15
15
  """Flower client app."""
16
16
 
17
17
 
18
- import signal
19
- import subprocess
18
+ import multiprocessing
19
+ import os
20
20
  import sys
21
+ import threading
21
22
  import time
22
23
  from contextlib import AbstractContextManager
23
- from dataclasses import dataclass
24
24
  from logging import ERROR, INFO, WARN
25
+ from os import urandom
25
26
  from pathlib import Path
26
27
  from typing import Callable, Optional, Union, cast
27
28
 
@@ -33,6 +34,7 @@ from flwr.cli.config_utils import get_fab_metadata
33
34
  from flwr.cli.install import install_from_fab
34
35
  from flwr.client.client import Client
35
36
  from flwr.client.client_app import ClientApp, LoadClientAppError
37
+ from flwr.client.clientapp.app import flwr_clientapp
36
38
  from flwr.client.nodestate.nodestate_factory import NodeStateFactory
37
39
  from flwr.client.typing import ClientFnExt
38
40
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH, Context, EventType, Message, event
@@ -43,7 +45,6 @@ from flwr.common.constant import (
43
45
  ISOLATION_MODE_PROCESS,
44
46
  ISOLATION_MODE_SUBPROCESS,
45
47
  MAX_RETRY_DELAY,
46
- MISSING_EXTRA_REST,
47
48
  RUN_ID_NUM_BYTES,
48
49
  SERVER_OCTET,
49
50
  TRANSPORT_TYPE_GRPC_ADAPTER,
@@ -53,13 +54,13 @@ from flwr.common.constant import (
53
54
  TRANSPORT_TYPES,
54
55
  ErrorCode,
55
56
  )
57
+ from flwr.common.exit import ExitCode, flwr_exit
58
+ from flwr.common.grpc import generic_create_grpc_server
56
59
  from flwr.common.logger import log, warn_deprecated_feature
57
60
  from flwr.common.message import Error
58
61
  from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential
59
62
  from flwr.common.typing import Fab, Run, RunNotRunningException, UserConfig
60
63
  from flwr.proto.clientappio_pb2_grpc import add_ClientAppIoServicer_to_server
61
- from flwr.server.superlink.fleet.grpc_bidi.grpc_server import generic_create_grpc_server
62
- from flwr.server.superlink.linkstate.utils import generate_rand_int_from_bytes
63
64
 
64
65
  from .clientapp.clientappio_servicer import ClientAppInputs, ClientAppIoServicer
65
66
  from .grpc_adapter_client.connection import grpc_adapter
@@ -345,10 +346,7 @@ def start_client_internal(
345
346
  transport, server_address
346
347
  )
347
348
 
348
- app_state_tracker = _AppStateTracker()
349
-
350
349
  def _on_sucess(retry_state: RetryState) -> None:
351
- app_state_tracker.is_connected = True
352
350
  if retry_state.tries > 1:
353
351
  log(
354
352
  INFO,
@@ -358,7 +356,6 @@ def start_client_internal(
358
356
  )
359
357
 
360
358
  def _on_backoff(retry_state: RetryState) -> None:
361
- app_state_tracker.is_connected = False
362
359
  if retry_state.tries == 1:
363
360
  log(WARN, "Connection attempt failed, retrying...")
364
361
  else:
@@ -391,10 +388,11 @@ def start_client_internal(
391
388
  run_info_store: Optional[DeprecatedRunInfoStore] = None
392
389
  state_factory = NodeStateFactory()
393
390
  state = state_factory.state()
391
+ mp_spawn_context = multiprocessing.get_context("spawn")
394
392
 
395
393
  runs: dict[int, Run] = {}
396
394
 
397
- while not app_state_tracker.interrupt:
395
+ while True:
398
396
  sleep_duration: int = 0
399
397
  with connection(
400
398
  address,
@@ -433,9 +431,8 @@ def start_client_internal(
433
431
  node_config=node_config,
434
432
  )
435
433
 
436
- app_state_tracker.register_signal_handler()
437
434
  # pylint: disable=too-many-nested-blocks
438
- while not app_state_tracker.interrupt:
435
+ while True:
439
436
  try:
440
437
  # Receive
441
438
  message = receive()
@@ -513,7 +510,7 @@ def start_client_internal(
513
510
  # Docker container.
514
511
 
515
512
  # Generate SuperNode token
516
- token: int = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
513
+ token = int.from_bytes(urandom(RUN_ID_NUM_BYTES), "little")
517
514
 
518
515
  # Mode 1: SuperNode starts ClientApp as subprocess
519
516
  start_subprocess = isolation == ISOLATION_MODE_SUBPROCESS
@@ -549,12 +546,13 @@ def start_client_internal(
549
546
  ]
550
547
  command.append("--insecure")
551
548
 
552
- subprocess.run(
553
- command,
554
- stdout=None,
555
- stderr=None,
556
- check=True,
549
+ proc = mp_spawn_context.Process(
550
+ target=_run_flwr_clientapp,
551
+ args=(command, os.getpid()),
552
+ daemon=True,
557
553
  )
554
+ proc.start()
555
+ proc.join()
558
556
  else:
559
557
  # Wait for output to become available
560
558
  while not clientappio_servicer.has_outputs():
@@ -592,10 +590,7 @@ def start_client_internal(
592
590
  e_code = ErrorCode.LOAD_CLIENT_APP_EXCEPTION
593
591
  exc_entity = "SuperNode"
594
592
 
595
- if not app_state_tracker.interrupt:
596
- log(
597
- ERROR, "%s raised an exception", exc_entity, exc_info=ex
598
- )
593
+ log(ERROR, "%s raised an exception", exc_entity, exc_info=ex)
599
594
 
600
595
  # Create error message
601
596
  reply_message = message.create_error_reply(
@@ -621,19 +616,14 @@ def start_client_internal(
621
616
  run_id,
622
617
  )
623
618
  log(INFO, "")
624
-
625
- except StopIteration:
626
- sleep_duration = 0
627
- break
628
619
  # pylint: enable=too-many-nested-blocks
629
620
 
630
621
  # Unregister node
631
- if delete_node is not None and app_state_tracker.is_connected:
622
+ if delete_node is not None:
632
623
  delete_node() # pylint: disable=not-callable
633
624
 
634
625
  if sleep_duration == 0:
635
626
  log(INFO, "Disconnect and shut down")
636
- del app_state_tracker
637
627
  break
638
628
 
639
629
  # Sleep and reconnect afterwards
@@ -773,7 +763,10 @@ def _init_connection(transport: Optional[str], server_address: str) -> tuple[
773
763
  # Parse IP address
774
764
  parsed_address = parse_address(server_address)
775
765
  if not parsed_address:
776
- sys.exit(f"Server address ({server_address}) cannot be parsed.")
766
+ flwr_exit(
767
+ ExitCode.COMMON_ADDRESS_INVALID,
768
+ f"SuperLink address ({server_address}) cannot be parsed.",
769
+ )
777
770
  host, port, is_v6 = parsed_address
778
771
  address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}"
779
772
 
@@ -788,12 +781,9 @@ def _init_connection(transport: Optional[str], server_address: str) -> tuple[
788
781
 
789
782
  from .rest_client.connection import http_request_response
790
783
  except ModuleNotFoundError:
791
- sys.exit(MISSING_EXTRA_REST)
784
+ flwr_exit(ExitCode.COMMON_MISSING_EXTRA_REST)
792
785
  if server_address[:4] != "http":
793
- sys.exit(
794
- "When using the REST API, please provide `https://` or "
795
- "`http://` before the server address (e.g. `http://127.0.0.1:8080`)"
796
- )
786
+ flwr_exit(ExitCode.SUPERNODE_REST_ADDRESS_INVALID)
797
787
  connection, error_type = http_request_response, RequestsConnectionError
798
788
  elif transport == TRANSPORT_TYPE_GRPC_RERE:
799
789
  connection, error_type = grpc_request_response, RpcError
@@ -809,21 +799,19 @@ def _init_connection(transport: Optional[str], server_address: str) -> tuple[
809
799
  return connection, address, error_type
810
800
 
811
801
 
812
- @dataclass
813
- class _AppStateTracker:
814
- interrupt: bool = False
815
- is_connected: bool = False
816
-
817
- def register_signal_handler(self) -> None:
818
- """Register handlers for exit signals."""
802
+ def _run_flwr_clientapp(args: list[str], main_pid: int) -> None:
803
+ # Monitor the main process in case of SIGKILL
804
+ def main_process_monitor() -> None:
805
+ while True:
806
+ time.sleep(1)
807
+ if os.getppid() != main_pid:
808
+ os.kill(os.getpid(), 9)
819
809
 
820
- def signal_handler(sig, frame): # type: ignore
821
- # pylint: disable=unused-argument
822
- self.interrupt = True
823
- raise StopIteration from None
810
+ threading.Thread(target=main_process_monitor, daemon=True).start()
824
811
 
825
- signal.signal(signal.SIGINT, signal_handler)
826
- signal.signal(signal.SIGTERM, signal_handler)
812
+ # Run the command
813
+ sys.argv = args
814
+ flwr_clientapp()
827
815
 
828
816
 
829
817
  def run_clientappio_api_grpc(
@@ -16,7 +16,6 @@
16
16
 
17
17
 
18
18
  import argparse
19
- import sys
20
19
  import time
21
20
  from logging import DEBUG, ERROR, INFO
22
21
  from typing import Optional
@@ -29,6 +28,7 @@ from flwr.common import Context, Message
29
28
  from flwr.common.args import add_args_flwr_app_common
30
29
  from flwr.common.config import get_flwr_dir
31
30
  from flwr.common.constant import CLIENTAPPIO_API_DEFAULT_CLIENT_ADDRESS, ErrorCode
31
+ from flwr.common.exit import ExitCode, flwr_exit
32
32
  from flwr.common.grpc import create_channel
33
33
  from flwr.common.logger import log
34
34
  from flwr.common.message import Error
@@ -61,12 +61,10 @@ def flwr_clientapp() -> None:
61
61
  """Run process-isolated Flower ClientApp."""
62
62
  args = _parse_args_run_flwr_clientapp().parse_args()
63
63
  if not args.insecure:
64
- log(
65
- ERROR,
66
- "flwr-clientapp does not support TLS yet. "
67
- "Please use the '--insecure' flag.",
64
+ flwr_exit(
65
+ ExitCode.COMMON_TLS_NOT_SUPPORTED,
66
+ "flwr-clientapp does not support TLS yet.",
68
67
  )
69
- sys.exit(1)
70
68
 
71
69
  log(INFO, "Starting Flower ClientApp")
72
70
  log(
@@ -66,7 +66,7 @@ def get_load_client_app_fn(
66
66
  # `fab_hash` is not required since the app is loaded from `runtime_app_dir`.
67
67
  elif app_path is not None:
68
68
  config = get_project_config(runtime_app_dir)
69
- this_fab_version, this_fab_id = get_metadata_from_config(config)
69
+ this_fab_id, this_fab_version = get_metadata_from_config(config)
70
70
 
71
71
  if this_fab_version != fab_version or this_fab_id != fab_id:
72
72
  raise LoadClientAppError(
@@ -47,12 +47,6 @@ from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
47
47
  )
48
48
  from flwr.proto.transport_pb2_grpc import FlowerServiceStub # pylint: disable=E0611
49
49
 
50
- # The following flags can be uncommented for debugging. Other possible values:
51
- # https://github.com/grpc/grpc/blob/master/doc/environment_variables.md
52
- # import os
53
- # os.environ["GRPC_VERBOSITY"] = "debug"
54
- # os.environ["GRPC_TRACE"] = "tcp,http"
55
-
56
50
 
57
51
  def on_channel_state_change(channel_connectivity: str) -> None:
58
52
  """Log channel connectivity."""
@@ -15,67 +15,18 @@
15
15
  """Flower client interceptor."""
16
16
 
17
17
 
18
- import base64
19
- import collections
20
- from collections.abc import Sequence
21
- from logging import WARNING
22
- from typing import Any, Callable, Optional, Union
18
+ from typing import Any, Callable
23
19
 
24
20
  import grpc
25
21
  from cryptography.hazmat.primitives.asymmetric import ec
22
+ from google.protobuf.message import Message as GrpcMessage
26
23
 
27
- from flwr.common.logger import log
24
+ from flwr.common import now
25
+ from flwr.common.constant import PUBLIC_KEY_HEADER, SIGNATURE_HEADER, TIMESTAMP_HEADER
28
26
  from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
29
- bytes_to_public_key,
30
- compute_hmac,
31
- generate_shared_key,
32
27
  public_key_to_bytes,
28
+ sign_message,
33
29
  )
34
- from flwr.proto.fab_pb2 import GetFabRequest # pylint: disable=E0611
35
- from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
36
- CreateNodeRequest,
37
- DeleteNodeRequest,
38
- PingRequest,
39
- PullTaskInsRequest,
40
- PushTaskResRequest,
41
- )
42
- from flwr.proto.run_pb2 import GetRunRequest # pylint: disable=E0611
43
-
44
- _PUBLIC_KEY_HEADER = "public-key"
45
- _AUTH_TOKEN_HEADER = "auth-token"
46
-
47
- Request = Union[
48
- CreateNodeRequest,
49
- DeleteNodeRequest,
50
- PullTaskInsRequest,
51
- PushTaskResRequest,
52
- GetRunRequest,
53
- PingRequest,
54
- GetFabRequest,
55
- ]
56
-
57
-
58
- def _get_value_from_tuples(
59
- key_string: str, tuples: Sequence[tuple[str, Union[str, bytes]]]
60
- ) -> bytes:
61
- value = next((value for key, value in tuples if key == key_string), "")
62
- if isinstance(value, str):
63
- return value.encode()
64
-
65
- return value
66
-
67
-
68
- class _ClientCallDetails(
69
- collections.namedtuple(
70
- "_ClientCallDetails", ("method", "timeout", "metadata", "credentials")
71
- ),
72
- grpc.ClientCallDetails, # type: ignore
73
- ):
74
- """Details for each client call.
75
-
76
- The class will be passed on as the first argument in continuation function.
77
- In our case, `AuthenticateClientInterceptor` adds new metadata to the construct.
78
- """
79
30
 
80
31
 
81
32
  class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type: ignore
@@ -87,84 +38,33 @@ class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type:
87
38
  public_key: ec.EllipticCurvePublicKey,
88
39
  ):
89
40
  self.private_key = private_key
90
- self.public_key = public_key
91
- self.shared_secret: Optional[bytes] = None
92
- self.server_public_key: Optional[ec.EllipticCurvePublicKey] = None
93
- self.encoded_public_key = base64.urlsafe_b64encode(
94
- public_key_to_bytes(self.public_key)
95
- )
41
+ self.public_key_bytes = public_key_to_bytes(public_key)
96
42
 
97
43
  def intercept_unary_unary(
98
44
  self,
99
45
  continuation: Callable[[Any, Any], Any],
100
46
  client_call_details: grpc.ClientCallDetails,
101
- request: Request,
47
+ request: GrpcMessage,
102
48
  ) -> grpc.Call:
103
49
  """Flower client interceptor.
104
50
 
105
51
  Intercept unary call from client and add necessary authentication header in the
106
52
  RPC metadata.
107
53
  """
108
- metadata = []
109
- postprocess = False
110
- if client_call_details.metadata is not None:
111
- metadata = list(client_call_details.metadata)
112
-
113
- # Always add the public key header
114
- metadata.append(
115
- (
116
- _PUBLIC_KEY_HEADER,
117
- self.encoded_public_key,
118
- )
119
- )
120
-
121
- if isinstance(request, CreateNodeRequest):
122
- postprocess = True
123
- elif isinstance(
124
- request,
125
- (
126
- DeleteNodeRequest,
127
- PullTaskInsRequest,
128
- PushTaskResRequest,
129
- GetRunRequest,
130
- PingRequest,
131
- GetFabRequest,
132
- ),
133
- ):
134
- if self.shared_secret is None:
135
- raise RuntimeError("Failure to compute hmac")
136
-
137
- message_bytes = request.SerializeToString(deterministic=True)
138
- metadata.append(
139
- (
140
- _AUTH_TOKEN_HEADER,
141
- base64.urlsafe_b64encode(
142
- compute_hmac(self.shared_secret, message_bytes)
143
- ),
144
- )
145
- )
54
+ metadata = list(client_call_details.metadata or [])
146
55
 
147
- client_call_details = _ClientCallDetails(
148
- client_call_details.method,
149
- client_call_details.timeout,
150
- metadata,
151
- client_call_details.credentials,
152
- )
56
+ # Add the public key
57
+ metadata.append((PUBLIC_KEY_HEADER, self.public_key_bytes))
153
58
 
154
- response = continuation(client_call_details, request)
155
- if postprocess:
156
- server_public_key_bytes = base64.urlsafe_b64decode(
157
- _get_value_from_tuples(_PUBLIC_KEY_HEADER, response.initial_metadata())
158
- )
59
+ # Add timestamp
60
+ timestamp = now().isoformat()
61
+ metadata.append((TIMESTAMP_HEADER, timestamp))
159
62
 
160
- if server_public_key_bytes != b"":
161
- self.server_public_key = bytes_to_public_key(server_public_key_bytes)
162
- else:
163
- log(WARNING, "Can't get server public key, SuperLink may be offline")
63
+ # Sign and add the signature
64
+ signature = sign_message(self.private_key, timestamp.encode("ascii"))
65
+ metadata.append((SIGNATURE_HEADER, signature))
164
66
 
165
- if self.server_public_key is not None:
166
- self.shared_secret = generate_shared_key(
167
- self.private_key, self.server_public_key
168
- )
67
+ # Overwrite the metadata
68
+ details = client_call_details._replace(metadata=metadata)
169
69
 
170
- return response
70
+ return continuation(details, request)
@@ -29,7 +29,6 @@ from cryptography.hazmat.primitives.asymmetric import ec
29
29
 
30
30
  from flwr.client.heartbeat import start_ping_loop
31
31
  from flwr.client.message_handler.message_handler import validate_out_message
32
- from flwr.client.message_handler.task_handler import get_task_ins, validate_task_ins
33
32
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH
34
33
  from flwr.common.constant import (
35
34
  PING_BASE_MULTIPLIER,
@@ -41,7 +40,7 @@ from flwr.common.grpc import create_channel
41
40
  from flwr.common.logger import log
42
41
  from flwr.common.message import Message, Metadata
43
42
  from flwr.common.retry_invoker import RetryInvoker
44
- from flwr.common.serde import message_from_taskins, message_to_taskres, run_from_proto
43
+ from flwr.common.serde import message_from_proto, message_to_proto, run_from_proto
45
44
  from flwr.common.typing import Fab, Run, RunNotRunningException
46
45
  from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
47
46
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
@@ -49,13 +48,13 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
49
48
  DeleteNodeRequest,
50
49
  PingRequest,
51
50
  PingResponse,
52
- PullTaskInsRequest,
53
- PushTaskResRequest,
51
+ PullMessagesRequest,
52
+ PullMessagesResponse,
53
+ PushMessagesRequest,
54
54
  )
55
55
  from flwr.proto.fleet_pb2_grpc import FleetStub # pylint: disable=E0611
56
56
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
57
57
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
58
- from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
59
58
 
60
59
  from .client_interceptor import AuthenticateClientInterceptor
61
60
  from .grpc_adapter import GrpcAdapter
@@ -227,28 +226,31 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
227
226
  node = None
228
227
 
229
228
  def receive() -> Optional[Message]:
230
- """Receive next task from server."""
229
+ """Receive next message from server."""
231
230
  # Get Node
232
231
  if node is None:
233
232
  log(ERROR, "Node instance missing")
234
233
  return None
235
234
 
236
- # Request instructions (task) from server
237
- request = PullTaskInsRequest(node=node)
238
- response = retry_invoker.invoke(stub.PullTaskIns, request=request)
235
+ # Request instructions (message) from server
236
+ request = PullMessagesRequest(node=node)
237
+ response: PullMessagesResponse = retry_invoker.invoke(
238
+ stub.PullMessages, request=request
239
+ )
239
240
 
240
- # Get the current TaskIns
241
- task_ins: Optional[TaskIns] = get_task_ins(response)
241
+ # Get the current Messages
242
+ message_proto = (
243
+ None if len(response.messages_list) == 0 else response.messages_list[0]
244
+ )
242
245
 
243
- # Discard the current TaskIns if not valid
244
- if task_ins is not None and not (
245
- task_ins.task.consumer.node_id == node.node_id
246
- and validate_task_ins(task_ins)
246
+ # Discard the current message if not valid
247
+ if message_proto is not None and not (
248
+ message_proto.metadata.dst_node_id == node.node_id
247
249
  ):
248
- task_ins = None
250
+ message_proto = None
249
251
 
250
252
  # Construct the Message
251
- in_message = message_from_taskins(task_ins) if task_ins else None
253
+ in_message = message_from_proto(message_proto) if message_proto else None
252
254
 
253
255
  # Remember `metadata` of the in message
254
256
  nonlocal metadata
@@ -258,7 +260,7 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
258
260
  return in_message
259
261
 
260
262
  def send(message: Message) -> None:
261
- """Send task result back to server."""
263
+ """Send message reply to server."""
262
264
  # Get Node
263
265
  if node is None:
264
266
  log(ERROR, "Node instance missing")
@@ -275,12 +277,10 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
275
277
  log(ERROR, "Invalid out message")
276
278
  return
277
279
 
278
- # Construct TaskRes
279
- task_res = message_to_taskres(message)
280
-
281
- # Serialize ProtoBuf to bytes
282
- request = PushTaskResRequest(node=node, task_res_list=[task_res])
283
- _ = retry_invoker.invoke(stub.PushTaskRes, request)
280
+ # Serialize Message
281
+ message_proto = message_to_proto(message=message)
282
+ request = PushMessagesRequest(node=node, messages_list=[message_proto])
283
+ _ = retry_invoker.invoke(stub.PushMessages, request)
284
284
 
285
285
  # Cleanup
286
286
  metadata = None
@@ -311,3 +311,13 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
311
311
  yield (receive, send, create_node, delete_node, get_run, get_fab)
312
312
  except Exception as exc: # pylint: disable=broad-except
313
313
  log(ERROR, exc)
314
+ # Cleanup
315
+ finally:
316
+ try:
317
+ if node is not None:
318
+ # Disable retrying
319
+ retry_invoker.max_tries = 1
320
+ delete_node()
321
+ except grpc.RpcError:
322
+ pass
323
+ channel.close()
@@ -40,8 +40,12 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
40
40
  DeleteNodeResponse,
41
41
  PingRequest,
42
42
  PingResponse,
43
+ PullMessagesRequest,
44
+ PullMessagesResponse,
43
45
  PullTaskInsRequest,
44
46
  PullTaskInsResponse,
47
+ PushMessagesRequest,
48
+ PushMessagesResponse,
45
49
  PushTaskResRequest,
46
50
  PushTaskResResponse,
47
51
  )
@@ -132,12 +136,24 @@ class GrpcAdapter:
132
136
  """."""
133
137
  return self._send_and_receive(request, PullTaskInsResponse, **kwargs)
134
138
 
139
+ def PullMessages( # pylint: disable=C0103
140
+ self, request: PullMessagesRequest, **kwargs: Any
141
+ ) -> PullMessagesResponse:
142
+ """."""
143
+ return self._send_and_receive(request, PullMessagesResponse, **kwargs)
144
+
135
145
  def PushTaskRes( # pylint: disable=C0103
136
146
  self, request: PushTaskResRequest, **kwargs: Any
137
147
  ) -> PushTaskResResponse:
138
148
  """."""
139
149
  return self._send_and_receive(request, PushTaskResResponse, **kwargs)
140
150
 
151
+ def PushMessages( # pylint: disable=C0103
152
+ self, request: PushMessagesRequest, **kwargs: Any
153
+ ) -> PushMessagesResponse:
154
+ """."""
155
+ return self._send_and_receive(request, PushMessagesResponse, **kwargs)
156
+
141
157
  def GetRun( # pylint: disable=C0103
142
158
  self, request: GetRunRequest, **kwargs: Any
143
159
  ) -> GetRunResponse: