flwr 1.14.0__py3-none-any.whl → 1.15.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (103) hide show
  1. flwr/cli/auth_plugin/__init__.py +31 -0
  2. flwr/cli/auth_plugin/oidc_cli_plugin.py +150 -0
  3. flwr/cli/cli_user_auth_interceptor.py +6 -2
  4. flwr/cli/config_utils.py +24 -147
  5. flwr/cli/constant.py +27 -0
  6. flwr/cli/install.py +1 -1
  7. flwr/cli/log.py +18 -3
  8. flwr/cli/login/login.py +43 -8
  9. flwr/cli/ls.py +14 -5
  10. flwr/cli/new/templates/app/README.md.tpl +3 -2
  11. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +4 -4
  12. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
  13. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +4 -4
  14. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +2 -2
  15. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -2
  16. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +2 -2
  17. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +4 -4
  18. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +3 -3
  19. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
  20. flwr/cli/run/run.py +21 -11
  21. flwr/cli/stop.py +13 -4
  22. flwr/cli/utils.py +54 -40
  23. flwr/client/app.py +36 -48
  24. flwr/client/clientapp/app.py +19 -25
  25. flwr/client/clientapp/utils.py +1 -1
  26. flwr/client/grpc_client/connection.py +1 -12
  27. flwr/client/grpc_rere_client/client_interceptor.py +19 -119
  28. flwr/client/grpc_rere_client/connection.py +46 -36
  29. flwr/client/grpc_rere_client/grpc_adapter.py +12 -12
  30. flwr/client/message_handler/task_handler.py +0 -17
  31. flwr/client/rest_client/connection.py +34 -26
  32. flwr/client/supernode/app.py +18 -72
  33. flwr/common/args.py +25 -47
  34. flwr/common/auth_plugin/auth_plugin.py +34 -23
  35. flwr/common/config.py +166 -16
  36. flwr/common/constant.py +24 -9
  37. flwr/common/differential_privacy.py +2 -1
  38. flwr/common/exit/__init__.py +24 -0
  39. flwr/common/exit/exit.py +99 -0
  40. flwr/common/exit/exit_code.py +93 -0
  41. flwr/common/exit_handlers.py +32 -30
  42. flwr/common/grpc.py +167 -4
  43. flwr/common/logger.py +26 -7
  44. flwr/common/object_ref.py +0 -14
  45. flwr/common/record/recordset.py +1 -1
  46. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +45 -0
  47. flwr/common/serde.py +6 -4
  48. flwr/common/typing.py +20 -0
  49. flwr/proto/clientappio_pb2.py +1 -1
  50. flwr/proto/error_pb2.py +1 -1
  51. flwr/proto/exec_pb2.py +13 -25
  52. flwr/proto/exec_pb2.pyi +27 -54
  53. flwr/proto/fab_pb2.py +1 -1
  54. flwr/proto/fleet_pb2.py +31 -31
  55. flwr/proto/fleet_pb2.pyi +23 -23
  56. flwr/proto/fleet_pb2_grpc.py +30 -30
  57. flwr/proto/fleet_pb2_grpc.pyi +20 -20
  58. flwr/proto/grpcadapter_pb2.py +1 -1
  59. flwr/proto/log_pb2.py +1 -1
  60. flwr/proto/message_pb2.py +1 -1
  61. flwr/proto/node_pb2.py +3 -3
  62. flwr/proto/node_pb2.pyi +1 -4
  63. flwr/proto/recordset_pb2.py +1 -1
  64. flwr/proto/run_pb2.py +1 -1
  65. flwr/proto/serverappio_pb2.py +24 -25
  66. flwr/proto/serverappio_pb2.pyi +26 -32
  67. flwr/proto/serverappio_pb2_grpc.py +28 -28
  68. flwr/proto/serverappio_pb2_grpc.pyi +16 -16
  69. flwr/proto/simulationio_pb2.py +1 -1
  70. flwr/proto/task_pb2.py +1 -1
  71. flwr/proto/transport_pb2.py +1 -1
  72. flwr/server/app.py +116 -128
  73. flwr/server/compat/app_utils.py +0 -1
  74. flwr/server/compat/driver_client_proxy.py +1 -2
  75. flwr/server/driver/grpc_driver.py +32 -27
  76. flwr/server/driver/inmemory_driver.py +2 -1
  77. flwr/server/serverapp/app.py +12 -10
  78. flwr/server/superlink/driver/serverappio_grpc.py +1 -1
  79. flwr/server/superlink/driver/serverappio_servicer.py +74 -48
  80. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +20 -88
  81. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +2 -165
  82. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +25 -24
  83. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +110 -168
  84. flwr/server/superlink/fleet/message_handler/message_handler.py +37 -24
  85. flwr/server/superlink/fleet/rest_rere/rest_api.py +16 -18
  86. flwr/server/superlink/fleet/vce/vce_api.py +2 -2
  87. flwr/server/superlink/linkstate/in_memory_linkstate.py +45 -75
  88. flwr/server/superlink/linkstate/linkstate.py +17 -38
  89. flwr/server/superlink/linkstate/sqlite_linkstate.py +81 -145
  90. flwr/server/superlink/linkstate/utils.py +18 -8
  91. flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
  92. flwr/server/utils/validator.py +9 -34
  93. flwr/simulation/app.py +4 -6
  94. flwr/simulation/legacy_app.py +4 -2
  95. flwr/simulation/run_simulation.py +1 -1
  96. flwr/simulation/simulationio_connection.py +2 -1
  97. flwr/superexec/exec_grpc.py +1 -1
  98. flwr/superexec/exec_servicer.py +23 -2
  99. {flwr-1.14.0.dist-info → flwr-1.15.1.dist-info}/METADATA +8 -8
  100. {flwr-1.14.0.dist-info → flwr-1.15.1.dist-info}/RECORD +103 -97
  101. {flwr-1.14.0.dist-info → flwr-1.15.1.dist-info}/LICENSE +0 -0
  102. {flwr-1.14.0.dist-info → flwr-1.15.1.dist-info}/WHEEL +0 -0
  103. {flwr-1.14.0.dist-info → flwr-1.15.1.dist-info}/entry_points.txt +0 -0
@@ -16,7 +16,6 @@
16
16
 
17
17
 
18
18
  import argparse
19
- import sys
20
19
  import time
21
20
  from logging import DEBUG, ERROR, INFO
22
21
  from typing import Optional
@@ -29,7 +28,8 @@ from flwr.common import Context, Message
29
28
  from flwr.common.args import add_args_flwr_app_common
30
29
  from flwr.common.config import get_flwr_dir
31
30
  from flwr.common.constant import CLIENTAPPIO_API_DEFAULT_CLIENT_ADDRESS, ErrorCode
32
- from flwr.common.grpc import create_channel
31
+ from flwr.common.exit import ExitCode, flwr_exit
32
+ from flwr.common.grpc import create_channel, on_channel_state_change
33
33
  from flwr.common.logger import log
34
34
  from flwr.common.message import Error
35
35
  from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub
@@ -61,18 +61,16 @@ def flwr_clientapp() -> None:
61
61
  """Run process-isolated Flower ClientApp."""
62
62
  args = _parse_args_run_flwr_clientapp().parse_args()
63
63
  if not args.insecure:
64
- log(
65
- ERROR,
66
- "flwr-clientapp does not support TLS yet. "
67
- "Please use the '--insecure' flag.",
64
+ flwr_exit(
65
+ ExitCode.COMMON_TLS_NOT_SUPPORTED,
66
+ "flwr-clientapp does not support TLS yet.",
68
67
  )
69
- sys.exit(1)
70
68
 
71
- log(INFO, "Starting Flower ClientApp")
69
+ log(INFO, "Start `flwr-clientapp` process")
72
70
  log(
73
71
  DEBUG,
74
- "Starting isolated `ClientApp` connected to SuperNode's ClientAppIo API at %s "
75
- "with token %s",
72
+ "`flwr-clientapp` will attempt to connect to SuperNode's "
73
+ "ClientAppIo API at %s with token %s",
76
74
  args.clientappio_api_address,
77
75
  args.token,
78
76
  )
@@ -85,11 +83,6 @@ def flwr_clientapp() -> None:
85
83
  )
86
84
 
87
85
 
88
- def on_channel_state_change(channel_connectivity: str) -> None:
89
- """Log channel connectivity."""
90
- log(DEBUG, channel_connectivity)
91
-
92
-
93
86
  def run_clientapp( # pylint: disable=R0914
94
87
  clientappio_api_address: str,
95
88
  run_once: bool,
@@ -118,11 +111,11 @@ def run_clientapp( # pylint: disable=R0914
118
111
  time.sleep(1)
119
112
 
120
113
  # Pull Message, Context, Run and (optional) FAB from SuperNode
121
- message, context, run, fab = pull_message(stub=stub, token=token)
114
+ message, context, run, fab = pull_clientappinputs(stub=stub, token=token)
122
115
 
123
116
  # Install FAB, if provided
124
117
  if fab:
125
- log(DEBUG, "Flower ClientApp starts FAB installation.")
118
+ log(DEBUG, "[flwr-clientapp] Start FAB installation.")
126
119
  install_from_fab(fab.content, flwr_dir=flwr_dir_, skip_prompt=True)
127
120
 
128
121
  load_client_app_fn = get_load_client_app_fn(
@@ -134,6 +127,7 @@ def run_clientapp( # pylint: disable=R0914
134
127
 
135
128
  try:
136
129
  # Load ClientApp
130
+ log(DEBUG, "[flwr-clientapp] Start `ClientApp` Loading.")
137
131
  client_app: ClientApp = load_client_app_fn(
138
132
  run.fab_id, run.fab_version, fab.hash_str if fab else ""
139
133
  )
@@ -162,7 +156,7 @@ def run_clientapp( # pylint: disable=R0914
162
156
  )
163
157
 
164
158
  # Push Message and Context to SuperNode
165
- _ = push_message(
159
+ _ = push_clientappoutputs(
166
160
  stub=stub, token=token, message=reply_message, context=context
167
161
  )
168
162
 
@@ -185,7 +179,7 @@ def run_clientapp( # pylint: disable=R0914
185
179
 
186
180
  def get_token(stub: grpc.Channel) -> Optional[int]:
187
181
  """Get a token from SuperNode."""
188
- log(DEBUG, "Flower ClientApp process requests token")
182
+ log(DEBUG, "[flwr-clientapp] Request token")
189
183
  try:
190
184
  res: GetTokenResponse = stub.GetToken(GetTokenRequest())
191
185
  log(DEBUG, "[GetToken] Received token: %s", res.token)
@@ -198,11 +192,11 @@ def get_token(stub: grpc.Channel) -> Optional[int]:
198
192
  return None
199
193
 
200
194
 
201
- def pull_message(
195
+ def pull_clientappinputs(
202
196
  stub: grpc.Channel, token: int
203
197
  ) -> tuple[Message, Context, Run, Optional[Fab]]:
204
- """Pull message from SuperNode to ClientApp."""
205
- log(INFO, "Pulling ClientAppInputs for token %s", token)
198
+ """Pull ClientAppInputs from SuperNode."""
199
+ log(INFO, "[flwr-clientapp] Pull `ClientAppInputs` for token %s", token)
206
200
  try:
207
201
  res: PullClientAppInputsResponse = stub.PullClientAppInputs(
208
202
  PullClientAppInputsRequest(token=token)
@@ -217,11 +211,11 @@ def pull_message(
217
211
  raise e
218
212
 
219
213
 
220
- def push_message(
214
+ def push_clientappoutputs(
221
215
  stub: grpc.Channel, token: int, message: Message, context: Context
222
216
  ) -> PushClientAppOutputsResponse:
223
- """Push message to SuperNode from ClientApp."""
224
- log(INFO, "Pushing ClientAppOutputs for token %s", token)
217
+ """Push ClientAppOutputs to SuperNode."""
218
+ log(INFO, "[flwr-clientapp] Push `ClientAppOutputs` for token %s", token)
225
219
  proto_message = message_to_proto(message)
226
220
  proto_context = context_to_proto(context)
227
221
 
@@ -66,7 +66,7 @@ def get_load_client_app_fn(
66
66
  # `fab_hash` is not required since the app is loaded from `runtime_app_dir`.
67
67
  elif app_path is not None:
68
68
  config = get_project_config(runtime_app_dir)
69
- this_fab_version, this_fab_id = get_metadata_from_config(config)
69
+ this_fab_id, this_fab_version = get_metadata_from_config(config)
70
70
 
71
71
  if this_fab_version != fab_version or this_fab_id != fab_id:
72
72
  raise LoadClientAppError(
@@ -36,7 +36,7 @@ from flwr.common import (
36
36
  from flwr.common import recordset_compat as compat
37
37
  from flwr.common import serde
38
38
  from flwr.common.constant import MessageType, MessageTypeLegacy
39
- from flwr.common.grpc import create_channel
39
+ from flwr.common.grpc import create_channel, on_channel_state_change
40
40
  from flwr.common.logger import log
41
41
  from flwr.common.retry_invoker import RetryInvoker
42
42
  from flwr.common.typing import Fab, Run
@@ -47,17 +47,6 @@ from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
47
47
  )
48
48
  from flwr.proto.transport_pb2_grpc import FlowerServiceStub # pylint: disable=E0611
49
49
 
50
- # The following flags can be uncommented for debugging. Other possible values:
51
- # https://github.com/grpc/grpc/blob/master/doc/environment_variables.md
52
- # import os
53
- # os.environ["GRPC_VERBOSITY"] = "debug"
54
- # os.environ["GRPC_TRACE"] = "tcp,http"
55
-
56
-
57
- def on_channel_state_change(channel_connectivity: str) -> None:
58
- """Log channel connectivity."""
59
- log(DEBUG, channel_connectivity)
60
-
61
50
 
62
51
  @contextmanager
63
52
  def grpc_connection( # pylint: disable=R0913,R0915,too-many-positional-arguments
@@ -15,67 +15,18 @@
15
15
  """Flower client interceptor."""
16
16
 
17
17
 
18
- import base64
19
- import collections
20
- from collections.abc import Sequence
21
- from logging import WARNING
22
- from typing import Any, Callable, Optional, Union
18
+ from typing import Any, Callable
23
19
 
24
20
  import grpc
25
21
  from cryptography.hazmat.primitives.asymmetric import ec
22
+ from google.protobuf.message import Message as GrpcMessage
26
23
 
27
- from flwr.common.logger import log
24
+ from flwr.common import now
25
+ from flwr.common.constant import PUBLIC_KEY_HEADER, SIGNATURE_HEADER, TIMESTAMP_HEADER
28
26
  from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
29
- bytes_to_public_key,
30
- compute_hmac,
31
- generate_shared_key,
32
27
  public_key_to_bytes,
28
+ sign_message,
33
29
  )
34
- from flwr.proto.fab_pb2 import GetFabRequest # pylint: disable=E0611
35
- from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
36
- CreateNodeRequest,
37
- DeleteNodeRequest,
38
- PingRequest,
39
- PullTaskInsRequest,
40
- PushTaskResRequest,
41
- )
42
- from flwr.proto.run_pb2 import GetRunRequest # pylint: disable=E0611
43
-
44
- _PUBLIC_KEY_HEADER = "public-key"
45
- _AUTH_TOKEN_HEADER = "auth-token"
46
-
47
- Request = Union[
48
- CreateNodeRequest,
49
- DeleteNodeRequest,
50
- PullTaskInsRequest,
51
- PushTaskResRequest,
52
- GetRunRequest,
53
- PingRequest,
54
- GetFabRequest,
55
- ]
56
-
57
-
58
- def _get_value_from_tuples(
59
- key_string: str, tuples: Sequence[tuple[str, Union[str, bytes]]]
60
- ) -> bytes:
61
- value = next((value for key, value in tuples if key == key_string), "")
62
- if isinstance(value, str):
63
- return value.encode()
64
-
65
- return value
66
-
67
-
68
- class _ClientCallDetails(
69
- collections.namedtuple(
70
- "_ClientCallDetails", ("method", "timeout", "metadata", "credentials")
71
- ),
72
- grpc.ClientCallDetails, # type: ignore
73
- ):
74
- """Details for each client call.
75
-
76
- The class will be passed on as the first argument in continuation function.
77
- In our case, `AuthenticateClientInterceptor` adds new metadata to the construct.
78
- """
79
30
 
80
31
 
81
32
  class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type: ignore
@@ -87,84 +38,33 @@ class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type:
87
38
  public_key: ec.EllipticCurvePublicKey,
88
39
  ):
89
40
  self.private_key = private_key
90
- self.public_key = public_key
91
- self.shared_secret: Optional[bytes] = None
92
- self.server_public_key: Optional[ec.EllipticCurvePublicKey] = None
93
- self.encoded_public_key = base64.urlsafe_b64encode(
94
- public_key_to_bytes(self.public_key)
95
- )
41
+ self.public_key_bytes = public_key_to_bytes(public_key)
96
42
 
97
43
  def intercept_unary_unary(
98
44
  self,
99
45
  continuation: Callable[[Any, Any], Any],
100
46
  client_call_details: grpc.ClientCallDetails,
101
- request: Request,
47
+ request: GrpcMessage,
102
48
  ) -> grpc.Call:
103
49
  """Flower client interceptor.
104
50
 
105
51
  Intercept unary call from client and add necessary authentication header in the
106
52
  RPC metadata.
107
53
  """
108
- metadata = []
109
- postprocess = False
110
- if client_call_details.metadata is not None:
111
- metadata = list(client_call_details.metadata)
112
-
113
- # Always add the public key header
114
- metadata.append(
115
- (
116
- _PUBLIC_KEY_HEADER,
117
- self.encoded_public_key,
118
- )
119
- )
120
-
121
- if isinstance(request, CreateNodeRequest):
122
- postprocess = True
123
- elif isinstance(
124
- request,
125
- (
126
- DeleteNodeRequest,
127
- PullTaskInsRequest,
128
- PushTaskResRequest,
129
- GetRunRequest,
130
- PingRequest,
131
- GetFabRequest,
132
- ),
133
- ):
134
- if self.shared_secret is None:
135
- raise RuntimeError("Failure to compute hmac")
136
-
137
- message_bytes = request.SerializeToString(deterministic=True)
138
- metadata.append(
139
- (
140
- _AUTH_TOKEN_HEADER,
141
- base64.urlsafe_b64encode(
142
- compute_hmac(self.shared_secret, message_bytes)
143
- ),
144
- )
145
- )
54
+ metadata = list(client_call_details.metadata or [])
146
55
 
147
- client_call_details = _ClientCallDetails(
148
- client_call_details.method,
149
- client_call_details.timeout,
150
- metadata,
151
- client_call_details.credentials,
152
- )
56
+ # Add the public key
57
+ metadata.append((PUBLIC_KEY_HEADER, self.public_key_bytes))
153
58
 
154
- response = continuation(client_call_details, request)
155
- if postprocess:
156
- server_public_key_bytes = base64.urlsafe_b64decode(
157
- _get_value_from_tuples(_PUBLIC_KEY_HEADER, response.initial_metadata())
158
- )
59
+ # Add timestamp
60
+ timestamp = now().isoformat()
61
+ metadata.append((TIMESTAMP_HEADER, timestamp))
159
62
 
160
- if server_public_key_bytes != b"":
161
- self.server_public_key = bytes_to_public_key(server_public_key_bytes)
162
- else:
163
- log(WARNING, "Can't get server public key, SuperLink may be offline")
63
+ # Sign and add the signature
64
+ signature = sign_message(self.private_key, timestamp.encode("ascii"))
65
+ metadata.append((SIGNATURE_HEADER, signature))
164
66
 
165
- if self.server_public_key is not None:
166
- self.shared_secret = generate_shared_key(
167
- self.private_key, self.server_public_key
168
- )
67
+ # Overwrite the metadata
68
+ details = client_call_details._replace(metadata=metadata)
169
69
 
170
- return response
70
+ return continuation(details, request)
@@ -20,7 +20,7 @@ import threading
20
20
  from collections.abc import Iterator, Sequence
21
21
  from contextlib import contextmanager
22
22
  from copy import copy
23
- from logging import DEBUG, ERROR
23
+ from logging import ERROR
24
24
  from pathlib import Path
25
25
  from typing import Callable, Optional, Union, cast
26
26
 
@@ -29,7 +29,6 @@ from cryptography.hazmat.primitives.asymmetric import ec
29
29
 
30
30
  from flwr.client.heartbeat import start_ping_loop
31
31
  from flwr.client.message_handler.message_handler import validate_out_message
32
- from flwr.client.message_handler.task_handler import get_task_ins, validate_task_ins
33
32
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH
34
33
  from flwr.common.constant import (
35
34
  PING_BASE_MULTIPLIER,
@@ -37,11 +36,14 @@ from flwr.common.constant import (
37
36
  PING_DEFAULT_INTERVAL,
38
37
  PING_RANDOM_RANGE,
39
38
  )
40
- from flwr.common.grpc import create_channel
39
+ from flwr.common.grpc import create_channel, on_channel_state_change
41
40
  from flwr.common.logger import log
42
41
  from flwr.common.message import Message, Metadata
43
42
  from flwr.common.retry_invoker import RetryInvoker
44
- from flwr.common.serde import message_from_taskins, message_to_taskres, run_from_proto
43
+ from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
44
+ generate_key_pairs,
45
+ )
46
+ from flwr.common.serde import message_from_proto, message_to_proto, run_from_proto
45
47
  from flwr.common.typing import Fab, Run, RunNotRunningException
46
48
  from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
47
49
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
@@ -49,23 +51,18 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
49
51
  DeleteNodeRequest,
50
52
  PingRequest,
51
53
  PingResponse,
52
- PullTaskInsRequest,
53
- PushTaskResRequest,
54
+ PullMessagesRequest,
55
+ PullMessagesResponse,
56
+ PushMessagesRequest,
54
57
  )
55
58
  from flwr.proto.fleet_pb2_grpc import FleetStub # pylint: disable=E0611
56
59
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
57
60
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
58
- from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
59
61
 
60
62
  from .client_interceptor import AuthenticateClientInterceptor
61
63
  from .grpc_adapter import GrpcAdapter
62
64
 
63
65
 
64
- def on_channel_state_change(channel_connectivity: str) -> None:
65
- """Log channel connectivity."""
66
- log(DEBUG, channel_connectivity)
67
-
68
-
69
66
  @contextmanager
70
67
  def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
71
68
  server_address: str,
@@ -131,12 +128,14 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
131
128
  if isinstance(root_certificates, str):
132
129
  root_certificates = Path(root_certificates).read_bytes()
133
130
 
134
- interceptors: Optional[Sequence[grpc.UnaryUnaryClientInterceptor]] = None
135
- if authentication_keys is not None:
136
- interceptors = AuthenticateClientInterceptor(
137
- authentication_keys[0], authentication_keys[1]
138
- )
131
+ # Automatic node auth: generate keys if user didn't provide any
132
+ if authentication_keys is None:
133
+ authentication_keys = generate_key_pairs()
139
134
 
135
+ # Always configure auth interceptor, with either user-provided or generated keys
136
+ interceptors: Sequence[grpc.UnaryUnaryClientInterceptor] = [
137
+ AuthenticateClientInterceptor(*authentication_keys),
138
+ ]
140
139
  channel = create_channel(
141
140
  server_address=server_address,
142
141
  insecure=insecure,
@@ -227,28 +226,31 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
227
226
  node = None
228
227
 
229
228
  def receive() -> Optional[Message]:
230
- """Receive next task from server."""
229
+ """Receive next message from server."""
231
230
  # Get Node
232
231
  if node is None:
233
232
  log(ERROR, "Node instance missing")
234
233
  return None
235
234
 
236
- # Request instructions (task) from server
237
- request = PullTaskInsRequest(node=node)
238
- response = retry_invoker.invoke(stub.PullTaskIns, request=request)
235
+ # Request instructions (message) from server
236
+ request = PullMessagesRequest(node=node)
237
+ response: PullMessagesResponse = retry_invoker.invoke(
238
+ stub.PullMessages, request=request
239
+ )
239
240
 
240
- # Get the current TaskIns
241
- task_ins: Optional[TaskIns] = get_task_ins(response)
241
+ # Get the current Messages
242
+ message_proto = (
243
+ None if len(response.messages_list) == 0 else response.messages_list[0]
244
+ )
242
245
 
243
- # Discard the current TaskIns if not valid
244
- if task_ins is not None and not (
245
- task_ins.task.consumer.node_id == node.node_id
246
- and validate_task_ins(task_ins)
246
+ # Discard the current message if not valid
247
+ if message_proto is not None and not (
248
+ message_proto.metadata.dst_node_id == node.node_id
247
249
  ):
248
- task_ins = None
250
+ message_proto = None
249
251
 
250
252
  # Construct the Message
251
- in_message = message_from_taskins(task_ins) if task_ins else None
253
+ in_message = message_from_proto(message_proto) if message_proto else None
252
254
 
253
255
  # Remember `metadata` of the in message
254
256
  nonlocal metadata
@@ -258,7 +260,7 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
258
260
  return in_message
259
261
 
260
262
  def send(message: Message) -> None:
261
- """Send task result back to server."""
263
+ """Send message reply to server."""
262
264
  # Get Node
263
265
  if node is None:
264
266
  log(ERROR, "Node instance missing")
@@ -275,12 +277,10 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
275
277
  log(ERROR, "Invalid out message")
276
278
  return
277
279
 
278
- # Construct TaskRes
279
- task_res = message_to_taskres(message)
280
-
281
- # Serialize ProtoBuf to bytes
282
- request = PushTaskResRequest(node=node, task_res_list=[task_res])
283
- _ = retry_invoker.invoke(stub.PushTaskRes, request)
280
+ # Serialize Message
281
+ message_proto = message_to_proto(message=message)
282
+ request = PushMessagesRequest(node=node, messages_list=[message_proto])
283
+ _ = retry_invoker.invoke(stub.PushMessages, request)
284
284
 
285
285
  # Cleanup
286
286
  metadata = None
@@ -311,3 +311,13 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
311
311
  yield (receive, send, create_node, delete_node, get_run, get_fab)
312
312
  except Exception as exc: # pylint: disable=broad-except
313
313
  log(ERROR, exc)
314
+ # Cleanup
315
+ finally:
316
+ try:
317
+ if node is not None:
318
+ # Disable retrying
319
+ retry_invoker.max_tries = 1
320
+ delete_node()
321
+ except grpc.RpcError:
322
+ pass
323
+ channel.close()
@@ -40,10 +40,10 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
40
40
  DeleteNodeResponse,
41
41
  PingRequest,
42
42
  PingResponse,
43
- PullTaskInsRequest,
44
- PullTaskInsResponse,
45
- PushTaskResRequest,
46
- PushTaskResResponse,
43
+ PullMessagesRequest,
44
+ PullMessagesResponse,
45
+ PushMessagesRequest,
46
+ PushMessagesResponse,
47
47
  )
48
48
  from flwr.proto.grpcadapter_pb2 import MessageContainer # pylint: disable=E0611
49
49
  from flwr.proto.grpcadapter_pb2_grpc import GrpcAdapterStub
@@ -126,17 +126,17 @@ class GrpcAdapter:
126
126
  """."""
127
127
  return self._send_and_receive(request, PingResponse, **kwargs)
128
128
 
129
- def PullTaskIns( # pylint: disable=C0103
130
- self, request: PullTaskInsRequest, **kwargs: Any
131
- ) -> PullTaskInsResponse:
129
+ def PullMessages( # pylint: disable=C0103
130
+ self, request: PullMessagesRequest, **kwargs: Any
131
+ ) -> PullMessagesResponse:
132
132
  """."""
133
- return self._send_and_receive(request, PullTaskInsResponse, **kwargs)
133
+ return self._send_and_receive(request, PullMessagesResponse, **kwargs)
134
134
 
135
- def PushTaskRes( # pylint: disable=C0103
136
- self, request: PushTaskResRequest, **kwargs: Any
137
- ) -> PushTaskResResponse:
135
+ def PushMessages( # pylint: disable=C0103
136
+ self, request: PushMessagesRequest, **kwargs: Any
137
+ ) -> PushMessagesResponse:
138
138
  """."""
139
- return self._send_and_receive(request, PushTaskResResponse, **kwargs)
139
+ return self._send_and_receive(request, PushMessagesResponse, **kwargs)
140
140
 
141
141
  def GetRun( # pylint: disable=C0103
142
142
  self, request: GetRunRequest, **kwargs: Any
@@ -15,9 +15,6 @@
15
15
  """Task handling."""
16
16
 
17
17
 
18
- from typing import Optional
19
-
20
- from flwr.proto.fleet_pb2 import PullTaskInsResponse # pylint: disable=E0611
21
18
  from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
22
19
 
23
20
 
@@ -38,17 +35,3 @@ def validate_task_ins(task_ins: TaskIns) -> bool:
38
35
  if not (task_ins.HasField("task") and task_ins.task.HasField("recordset")):
39
36
  return False
40
37
  return True
41
-
42
-
43
- def get_task_ins(
44
- pull_task_ins_response: PullTaskInsResponse,
45
- ) -> Optional[TaskIns]:
46
- """Get the first TaskIns, if available."""
47
- # Extract a single ServerMessage from the response, if possible
48
- if len(pull_task_ins_response.task_ins_list) == 0:
49
- return None
50
-
51
- # Only evaluate the first message
52
- task_ins: TaskIns = pull_task_ins_response.task_ins_list[0]
53
-
54
- return task_ins