flwr-nightly 1.9.0.dev20240531__py3-none-any.whl → 1.10.0.dev20240619__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

Files changed (80)
  1. flwr/cli/app.py +2 -0
  2. flwr/cli/build.py +4 -15
  3. flwr/cli/config_utils.py +64 -7
  4. flwr/cli/install.py +211 -0
  5. flwr/cli/new/templates/app/pyproject.hf.toml.tpl +1 -1
  6. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  7. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
  8. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  9. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
  10. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  11. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  12. flwr/cli/run/run.py +39 -2
  13. flwr/cli/utils.py +14 -0
  14. flwr/client/__init__.py +1 -0
  15. flwr/client/app.py +153 -103
  16. flwr/client/client_app.py +1 -1
  17. flwr/client/grpc_adapter_client/__init__.py +15 -0
  18. flwr/client/grpc_adapter_client/connection.py +94 -0
  19. flwr/client/grpc_client/connection.py +5 -1
  20. flwr/client/grpc_rere_client/client_interceptor.py +1 -1
  21. flwr/client/grpc_rere_client/connection.py +9 -5
  22. flwr/client/grpc_rere_client/grpc_adapter.py +133 -0
  23. flwr/client/mod/__init__.py +4 -4
  24. flwr/client/rest_client/connection.py +10 -3
  25. flwr/client/supernode/app.py +155 -31
  26. flwr/common/__init__.py +12 -12
  27. flwr/common/config.py +71 -0
  28. flwr/common/constant.py +15 -0
  29. flwr/common/object_ref.py +52 -14
  30. flwr/common/record/__init__.py +1 -1
  31. flwr/common/telemetry.py +4 -0
  32. flwr/common/typing.py +9 -0
  33. flwr/proto/driver_pb2.py +20 -19
  34. flwr/proto/driver_pb2_grpc.py +35 -0
  35. flwr/proto/driver_pb2_grpc.pyi +14 -0
  36. flwr/proto/exec_pb2.py +34 -0
  37. flwr/proto/exec_pb2.pyi +55 -0
  38. flwr/proto/exec_pb2_grpc.py +101 -0
  39. flwr/proto/exec_pb2_grpc.pyi +41 -0
  40. flwr/proto/fab_pb2.py +30 -0
  41. flwr/proto/fab_pb2.pyi +56 -0
  42. flwr/proto/fab_pb2_grpc.py +4 -0
  43. flwr/proto/fab_pb2_grpc.pyi +4 -0
  44. flwr/proto/fleet_pb2.py +28 -33
  45. flwr/proto/fleet_pb2.pyi +0 -42
  46. flwr/proto/fleet_pb2_grpc.py +7 -6
  47. flwr/proto/fleet_pb2_grpc.pyi +5 -4
  48. flwr/proto/run_pb2.py +30 -0
  49. flwr/proto/run_pb2.pyi +52 -0
  50. flwr/proto/run_pb2_grpc.py +4 -0
  51. flwr/proto/run_pb2_grpc.pyi +4 -0
  52. flwr/server/__init__.py +2 -6
  53. flwr/server/app.py +94 -214
  54. flwr/server/run_serverapp.py +33 -7
  55. flwr/server/server_app.py +2 -2
  56. flwr/server/strategy/__init__.py +2 -2
  57. flwr/server/superlink/driver/driver_servicer.py +7 -0
  58. flwr/server/superlink/fleet/grpc_adapter/__init__.py +15 -0
  59. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +131 -0
  60. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +4 -0
  61. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +1 -2
  62. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +1 -2
  63. flwr/server/superlink/fleet/message_handler/message_handler.py +8 -6
  64. flwr/server/superlink/fleet/rest_rere/rest_api.py +1 -1
  65. flwr/server/superlink/fleet/vce/vce_api.py +3 -1
  66. flwr/server/superlink/state/in_memory_state.py +8 -5
  67. flwr/server/superlink/state/sqlite_state.py +6 -3
  68. flwr/server/superlink/state/state.py +5 -4
  69. flwr/simulation/__init__.py +4 -1
  70. flwr/simulation/run_simulation.py +22 -0
  71. flwr/superexec/__init__.py +21 -0
  72. flwr/superexec/app.py +178 -0
  73. flwr/superexec/exec_grpc.py +51 -0
  74. flwr/superexec/exec_servicer.py +65 -0
  75. flwr/superexec/executor.py +54 -0
  76. {flwr_nightly-1.9.0.dev20240531.dist-info → flwr_nightly-1.10.0.dev20240619.dist-info}/METADATA +1 -1
  77. {flwr_nightly-1.9.0.dev20240531.dist-info → flwr_nightly-1.10.0.dev20240619.dist-info}/RECORD +80 -56
  78. {flwr_nightly-1.9.0.dev20240531.dist-info → flwr_nightly-1.10.0.dev20240619.dist-info}/entry_points.txt +1 -2
  79. {flwr_nightly-1.9.0.dev20240531.dist-info → flwr_nightly-1.10.0.dev20240619.dist-info}/LICENSE +0 -0
  80. {flwr_nightly-1.9.0.dev20240531.dist-info → flwr_nightly-1.10.0.dev20240619.dist-info}/WHEEL +0 -0
flwr/client/grpc_rere_client/grpc_adapter.py ADDED
@@ -0,0 +1,133 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """GrpcAdapter implementation."""
16
+
17
+
18
+ import sys
19
+ from logging import DEBUG
20
+ from typing import Any, Type, TypeVar, cast
21
+
22
+ import grpc
23
+ from google.protobuf.message import Message as GrpcMessage
24
+
25
+ from flwr.common import log
26
+ from flwr.common.constant import (
27
+ GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY,
28
+ GRPC_ADAPTER_METADATA_SHOULD_EXIT_KEY,
29
+ )
30
+ from flwr.common.version import package_version
31
+ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
32
+ CreateNodeRequest,
33
+ CreateNodeResponse,
34
+ DeleteNodeRequest,
35
+ DeleteNodeResponse,
36
+ PingRequest,
37
+ PingResponse,
38
+ PullTaskInsRequest,
39
+ PullTaskInsResponse,
40
+ PushTaskResRequest,
41
+ PushTaskResResponse,
42
+ )
43
+ from flwr.proto.grpcadapter_pb2 import MessageContainer # pylint: disable=E0611
44
+ from flwr.proto.grpcadapter_pb2_grpc import GrpcAdapterStub
45
+ from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
46
+
47
T = TypeVar("T", bound=GrpcMessage)


class GrpcAdapter:
    """Send and receive Fleet API messages through the ``GrpcAdapterStub``.

    Each Fleet API request (as defined in ``fleet.proto``) is serialized and
    wrapped in a ``MessageContainer`` that travels over the single
    ``SendReceive`` RPC of ``GrpcAdapterStub``. The container returned by the
    server is then unwrapped into the response type the caller expects.
    """

    def __init__(self, channel: grpc.Channel) -> None:
        """Create the ``GrpcAdapterStub`` bound to ``channel``."""
        self.stub = GrpcAdapterStub(channel)

    def _send_and_receive(
        self, request: GrpcMessage, response_type: Type[T], **kwargs: Any
    ) -> T:
        """Wrap ``request``, send it, and unwrap the reply as ``response_type``.

        Parameters
        ----------
        request : GrpcMessage
            The protobuf request to send; its class ``__qualname__`` is sent
            as ``grpc_message_name``.
        response_type : Type[T]
            The protobuf class the reply is expected to contain.
        **kwargs : Any
            Forwarded unchanged to ``GrpcAdapterStub.SendReceive``.

        Returns
        -------
        T
            A new ``response_type`` instance parsed from the reply content.

        Raises
        ------
        ValueError
            If the reply's ``grpc_message_name`` differs from
            ``response_type.__qualname__``.
        SystemExit
            If the reply's metadata sets the should-exit key to ``"true"``.
        """
        outgoing = MessageContainer(
            metadata={GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY: package_version},
            grpc_message_name=request.__class__.__qualname__,
            grpc_message_content=request.SerializeToString(),
        )
        reply = cast(MessageContainer, self.stub.SendReceive(outgoing, **kwargs))

        # The server can request a client shutdown through a metadata flag;
        # a missing key is treated as "false".
        exit_flag = reply.metadata.get(GRPC_ADAPTER_METADATA_SHOULD_EXIT_KEY, "false")
        if exit_flag == "true":
            log(
                DEBUG,
                'Received shutdown signal: exit flag is set to ``"true"``. Exiting...',
            )
            sys.exit(0)

        # Guard against decoding a reply as the wrong protobuf type.
        expected_name = response_type.__qualname__
        if reply.grpc_message_name != expected_name:
            raise ValueError(
                f"Invalid grpc_message_name. Expected {expected_name}"
                f", but got {reply.grpc_message_name}."
            )

        decoded = response_type()
        decoded.ParseFromString(reply.grpc_message_content)
        return decoded

    def CreateNode(  # pylint: disable=C0103
        self, request: CreateNodeRequest, **kwargs: Any
    ) -> CreateNodeResponse:
        """Send a ``CreateNodeRequest`` and return the ``CreateNodeResponse``."""
        return self._send_and_receive(request, CreateNodeResponse, **kwargs)

    def DeleteNode(  # pylint: disable=C0103
        self, request: DeleteNodeRequest, **kwargs: Any
    ) -> DeleteNodeResponse:
        """Send a ``DeleteNodeRequest`` and return the ``DeleteNodeResponse``."""
        return self._send_and_receive(request, DeleteNodeResponse, **kwargs)

    def Ping(  # pylint: disable=C0103
        self, request: PingRequest, **kwargs: Any
    ) -> PingResponse:
        """Send a ``PingRequest`` and return the ``PingResponse``."""
        return self._send_and_receive(request, PingResponse, **kwargs)

    def PullTaskIns(  # pylint: disable=C0103
        self, request: PullTaskInsRequest, **kwargs: Any
    ) -> PullTaskInsResponse:
        """Send a ``PullTaskInsRequest`` and return the ``PullTaskInsResponse``."""
        return self._send_and_receive(request, PullTaskInsResponse, **kwargs)

    def PushTaskRes(  # pylint: disable=C0103
        self, request: PushTaskResRequest, **kwargs: Any
    ) -> PushTaskResResponse:
        """Send a ``PushTaskResRequest`` and return the ``PushTaskResResponse``."""
        return self._send_and_receive(request, PushTaskResResponse, **kwargs)

    def GetRun(  # pylint: disable=C0103
        self, request: GetRunRequest, **kwargs: Any
    ) -> GetRunResponse:
        """Send a ``GetRunRequest`` and return the ``GetRunResponse``."""
        return self._send_and_receive(request, GetRunResponse, **kwargs)
flwr/client/mod/__init__.py CHANGED
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """Mods."""
15
+ """Flower Built-in Mods."""
16
16
 
17
17
 
18
18
  from .centraldp_mods import adaptiveclipping_mod, fixedclipping_mod
@@ -22,12 +22,12 @@ from .secure_aggregation import secagg_mod, secaggplus_mod
22
22
  from .utils import make_ffn
23
23
 
24
24
  __all__ = [
25
+ "LocalDpMod",
25
26
  "adaptiveclipping_mod",
26
27
  "fixedclipping_mod",
27
- "LocalDpMod",
28
28
  "make_ffn",
29
- "secagg_mod",
30
- "secaggplus_mod",
31
29
  "message_size_mod",
32
30
  "parameters_size_mod",
31
+ "secagg_mod",
32
+ "secaggplus_mod",
33
33
  ]
flwr/client/rest_client/connection.py CHANGED
@@ -46,8 +46,6 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
46
46
  CreateNodeResponse,
47
47
  DeleteNodeRequest,
48
48
  DeleteNodeResponse,
49
- GetRunRequest,
50
- GetRunResponse,
51
49
  PingRequest,
52
50
  PingResponse,
53
51
  PullTaskInsRequest,
@@ -56,6 +54,7 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
56
54
  PushTaskResResponse,
57
55
  )
58
56
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
57
+ from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
59
58
  from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
60
59
 
61
60
  try:
@@ -118,10 +117,16 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915
118
117
  Path of the root certificate. If provided, a secure
119
118
  connection using the certificates will be established to an SSL-enabled
120
119
  Flower server. Bytes won't work for the REST API.
120
+ authentication_keys : Optional[Tuple[PrivateKey, PublicKey]] (default: None)
121
+ Client authentication is not supported for this transport type.
121
122
 
122
123
  Returns
123
124
  -------
124
- receive, send : Callable, Callable
125
+ receive : Callable
126
+ send : Callable
127
+ create_node : Optional[Callable]
128
+ delete_node : Optional[Callable]
129
+ get_run : Optional[Callable]
125
130
  """
126
131
  log(
127
132
  WARN,
@@ -146,6 +151,8 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915
146
151
  "For the REST API, the root certificates "
147
152
  "must be provided as a string path to the client.",
148
153
  )
154
+ if authentication_keys is not None:
155
+ log(ERROR, "Client authentication is not supported for this transport type.")
149
156
 
150
157
  # Shared variables for inner functions
151
158
  metadata: Optional[Metadata] = None
flwr/client/supernode/app.py CHANGED
@@ -29,12 +29,20 @@ from cryptography.hazmat.primitives.serialization import (
29
29
 
30
30
  from flwr.client.client_app import ClientApp, LoadClientAppError
31
31
  from flwr.common import EventType, event
32
+ from flwr.common.config import get_flwr_dir, get_project_config, get_project_dir
33
+ from flwr.common.constant import (
34
+ TRANSPORT_TYPE_GRPC_ADAPTER,
35
+ TRANSPORT_TYPE_GRPC_RERE,
36
+ TRANSPORT_TYPE_REST,
37
+ )
32
38
  from flwr.common.exit_handlers import register_exit_handlers
33
- from flwr.common.logger import log
39
+ from flwr.common.logger import log, warn_deprecated_feature
34
40
  from flwr.common.object_ref import load_app, validate
35
41
 
36
42
  from ..app import _start_client_internal
37
43
 
44
+ ADDRESS_FLEET_API_GRPC_RERE = "0.0.0.0:9092"
45
+
38
46
 
39
47
  def run_supernode() -> None:
40
48
  """Run Flower SuperNode."""
@@ -42,11 +50,23 @@ def run_supernode() -> None:
42
50
 
43
51
  event(EventType.RUN_SUPERNODE_ENTER)
44
52
 
45
- _ = _parse_args_run_supernode().parse_args()
53
+ args = _parse_args_run_supernode().parse_args()
54
+
55
+ _warn_deprecated_server_arg(args)
56
+
57
+ root_certificates = _get_certificates(args)
58
+ load_fn = _get_load_client_app_fn(args, multi_app=True)
59
+ authentication_keys = _try_setup_client_authentication(args)
46
60
 
47
- log(
48
- DEBUG,
49
- "Flower SuperNode starting...",
61
+ _start_client_internal(
62
+ server_address=args.superlink,
63
+ load_client_app_fn=load_fn,
64
+ transport=args.transport,
65
+ root_certificates=root_certificates,
66
+ insecure=args.insecure,
67
+ authentication_keys=authentication_keys,
68
+ max_retries=args.max_retries,
69
+ max_wait_time=args.max_wait_time,
50
70
  )
51
71
 
52
72
  # Graceful shutdown
@@ -63,19 +83,16 @@ def run_client_app() -> None:
63
83
 
64
84
  args = _parse_args_run_client_app().parse_args()
65
85
 
86
+ _warn_deprecated_server_arg(args)
87
+
66
88
  root_certificates = _get_certificates(args)
67
- log(
68
- DEBUG,
69
- "Flower will load ClientApp `%s`",
70
- getattr(args, "client-app"),
71
- )
72
- load_fn = _get_load_client_app_fn(args)
89
+ load_fn = _get_load_client_app_fn(args, multi_app=False)
73
90
  authentication_keys = _try_setup_client_authentication(args)
74
91
 
75
92
  _start_client_internal(
76
- server_address=args.server,
93
+ server_address=args.superlink,
77
94
  load_client_app_fn=load_fn,
78
- transport="rest" if args.rest else "grpc-rere",
95
+ transport=args.transport,
79
96
  root_certificates=root_certificates,
80
97
  insecure=args.insecure,
81
98
  authentication_keys=authentication_keys,
@@ -85,6 +102,26 @@ def run_client_app() -> None:
85
102
  register_exit_handlers(event_type=EventType.RUN_CLIENT_APP_LEAVE)
86
103
 
87
104
 
105
def _warn_deprecated_server_arg(args: argparse.Namespace) -> None:
    """Warn about the deprecated `--server` flag and reconcile it with `--superlink`.

    Both flags default to ``ADDRESS_FLEET_API_GRPC_RERE``; a value different
    from that default is treated as explicitly passed. When only `--server`
    was passed, its value is copied into ``args.superlink`` (``args`` is
    mutated in place). When both were passed, `--superlink` takes precedence
    and a warning is logged.
    """
    if args.server == ADDRESS_FLEET_API_GRPC_RERE:
        # `--server` left at its default: nothing to warn about or reconcile.
        return

    warn_deprecated_feature(
        "Passing flag --server is deprecated. Use --superlink instead."
    )

    if args.superlink == ADDRESS_FLEET_API_GRPC_RERE:
        # Only the legacy flag was given, so honour it as the SuperLink address.
        args.superlink = args.server
        return

    log(
        WARN,
        "Both `--server` and `--superlink` were passed. "
        "`--server` will be ignored. Connecting to the Superlink Fleet API "
        "at %s.",
        args.superlink,
    )
123
+
124
+
88
125
  def _get_certificates(args: argparse.Namespace) -> Optional[bytes]:
89
126
  """Load certificates if specified in args."""
90
127
  # Obtain certificates
@@ -100,7 +137,7 @@ def _get_certificates(args: argparse.Namespace) -> Optional[bytes]:
100
137
  WARN,
101
138
  "Option `--insecure` was set. "
102
139
  "Starting insecure HTTP client connected to %s.",
103
- args.server,
140
+ args.superlink,
104
141
  )
105
142
  root_certificates = None
106
143
  else:
@@ -114,31 +151,95 @@ def _get_certificates(args: argparse.Namespace) -> Optional[bytes]:
114
151
  DEBUG,
115
152
  "Starting secure HTTPS client connected to %s "
116
153
  "with the following certificates: %s.",
117
- args.server,
154
+ args.superlink,
118
155
  cert_path,
119
156
  )
120
157
  return root_certificates
121
158
 
122
159
 
123
160
  def _get_load_client_app_fn(
124
- args: argparse.Namespace,
125
- ) -> Callable[[], ClientApp]:
126
- """Get the load_client_app_fn function."""
127
- client_app_dir = args.dir
128
- if client_app_dir is not None:
129
- sys.path.insert(0, client_app_dir)
161
+ args: argparse.Namespace, multi_app: bool
162
+ ) -> Callable[[str, str], ClientApp]:
163
+ """Get the load_client_app_fn function.
164
+
165
+ If `multi_app` is True, this function loads the specified ClientApp
166
+ based on `fab_id` and `fab_version`. If `fab_id` is empty, a default
167
+ ClientApp will be loaded.
168
+
169
+ If `multi_app` is False, it ignores `fab_id` and `fab_version` and
170
+ loads a default ClientApp.
171
+ """
172
+ # Find the Flower directory containing Flower Apps (only for multi-app)
173
+ flwr_dir = Path("")
174
+ if "flwr_dir" in args:
175
+ if args.flwr_dir is None:
176
+ flwr_dir = get_flwr_dir()
177
+ else:
178
+ flwr_dir = Path(args.flwr_dir).absolute()
179
+
180
+ sys.path.insert(0, str(flwr_dir.absolute()))
181
+
182
+ default_app_ref: str = getattr(args, "client-app")
183
+
184
+ if not multi_app:
185
+ log(
186
+ DEBUG,
187
+ "Flower SuperNode will load and validate ClientApp `%s`",
188
+ getattr(args, "client-app"),
189
+ )
190
+ valid, error_msg = validate(default_app_ref)
191
+ if not valid and error_msg:
192
+ raise LoadClientAppError(error_msg) from None
193
+
194
+ def _load(fab_id: str, fab_version: str) -> ClientApp:
195
+ # If multi-app feature is disabled
196
+ if not multi_app:
197
+ # Get sys path to be inserted
198
+ sys_path = Path(args.dir).absolute()
199
+
200
+ # Set app reference
201
+ client_app_ref = default_app_ref
202
+ # If multi-app feature is enabled but the fab id is not specified
203
+ elif fab_id == "":
204
+ if default_app_ref == "":
205
+ raise LoadClientAppError(
206
+ "Invalid FAB ID: The FAB ID is empty.",
207
+ ) from None
208
+
209
+ log(WARN, "FAB ID is not provided; the default ClientApp will be loaded.")
210
+ # Get sys path to be inserted
211
+ sys_path = Path(args.dir).absolute()
212
+
213
+ # Set app reference
214
+ client_app_ref = default_app_ref
215
+ # If multi-app feature is enabled
216
+ else:
217
+ try:
218
+ project_dir = get_project_dir(fab_id, fab_version, flwr_dir)
219
+ config = get_project_config(project_dir)
220
+ except Exception as e:
221
+ raise LoadClientAppError("Failed to load ClientApp") from e
130
222
 
131
- app_ref: str = getattr(args, "client-app")
132
- valid, error_msg = validate(app_ref)
133
- if not valid and error_msg:
134
- raise LoadClientAppError(error_msg) from None
223
+ # Get sys path to be inserted
224
+ sys_path = Path(project_dir).absolute()
135
225
 
136
- def _load() -> ClientApp:
137
- client_app = load_app(app_ref, LoadClientAppError)
226
+ # Set app reference
227
+ client_app_ref = config["flower"]["components"]["clientapp"]
228
+
229
+ # Set sys.path
230
+ sys.path.insert(0, str(sys_path))
231
+
232
+ # Load ClientApp
233
+ log(
234
+ DEBUG,
235
+ "Loading ClientApp `%s`",
236
+ client_app_ref,
237
+ )
238
+ client_app = load_app(client_app_ref, LoadClientAppError, sys_path)
138
239
 
139
240
  if not isinstance(client_app, ClientApp):
140
241
  raise LoadClientAppError(
141
- f"Attribute {app_ref} is not of type {ClientApp}",
242
+ f"Attribute {client_app_ref} is not of type {ClientApp}",
142
243
  ) from None
143
244
 
144
245
  return client_app
@@ -199,9 +300,27 @@ def _parse_args_common(parser: argparse.ArgumentParser) -> None:
199
300
  help="Run the client without HTTPS. By default, the client runs with "
200
301
  "HTTPS enabled. Use this flag only if you understand the risks.",
201
302
  )
202
- parser.add_argument(
303
+ ex_group = parser.add_mutually_exclusive_group()
304
+ ex_group.add_argument(
305
+ "--grpc-rere",
306
+ action="store_const",
307
+ dest="transport",
308
+ const=TRANSPORT_TYPE_GRPC_RERE,
309
+ default=TRANSPORT_TYPE_GRPC_RERE,
310
+ help="Use grpc-rere as a transport layer for the client.",
311
+ )
312
+ ex_group.add_argument(
313
+ "--grpc-adapter",
314
+ action="store_const",
315
+ dest="transport",
316
+ const=TRANSPORT_TYPE_GRPC_ADAPTER,
317
+ help="Use grpc-adapter as a transport layer for the client.",
318
+ )
319
+ ex_group.add_argument(
203
320
  "--rest",
204
- action="store_true",
321
+ action="store_const",
322
+ dest="transport",
323
+ const=TRANSPORT_TYPE_REST,
205
324
  help="Use REST as a transport layer for the client.",
206
325
  )
207
326
  parser.add_argument(
@@ -213,9 +332,14 @@ def _parse_args_common(parser: argparse.ArgumentParser) -> None:
213
332
  )
214
333
  parser.add_argument(
215
334
  "--server",
216
- default="0.0.0.0:9092",
335
+ default=ADDRESS_FLEET_API_GRPC_RERE,
217
336
  help="Server address",
218
337
  )
338
+ parser.add_argument(
339
+ "--superlink",
340
+ default=ADDRESS_FLEET_API_GRPC_RERE,
341
+ help="SuperLink Fleet API (gRPC-rere) address (IPv4, IPv6, or a domain name)",
342
+ )
219
343
  parser.add_argument(
220
344
  "--max-retries",
221
345
  type=int,
flwr/common/__init__.py CHANGED
@@ -63,43 +63,34 @@ from .typing import Status as Status
63
63
 
64
64
  __all__ = [
65
65
  "Array",
66
- "array_from_numpy",
67
- "bytes_to_ndarray",
68
66
  "ClientMessage",
69
67
  "Code",
70
68
  "Config",
71
69
  "ConfigsRecord",
72
- "configure",
73
70
  "Context",
71
+ "DEFAULT_TTL",
74
72
  "DisconnectRes",
73
+ "Error",
75
74
  "EvaluateIns",
76
75
  "EvaluateRes",
77
- "event",
78
76
  "EventType",
79
77
  "FitIns",
80
78
  "FitRes",
81
- "Error",
79
+ "GRPC_MAX_MESSAGE_LENGTH",
82
80
  "GetParametersIns",
83
81
  "GetParametersRes",
84
82
  "GetPropertiesIns",
85
83
  "GetPropertiesRes",
86
- "GRPC_MAX_MESSAGE_LENGTH",
87
- "log",
88
84
  "Message",
89
85
  "MessageType",
90
86
  "MessageTypeLegacy",
91
- "DEFAULT_TTL",
92
87
  "Metadata",
93
88
  "Metrics",
94
89
  "MetricsAggregationFn",
95
90
  "MetricsRecord",
96
- "ndarray_to_bytes",
97
- "now",
98
91
  "NDArray",
99
92
  "NDArrays",
100
- "ndarrays_to_parameters",
101
93
  "Parameters",
102
- "parameters_to_ndarrays",
103
94
  "ParametersRecord",
104
95
  "Properties",
105
96
  "ReconnectIns",
@@ -107,4 +98,13 @@ __all__ = [
107
98
  "Scalar",
108
99
  "ServerMessage",
109
100
  "Status",
101
+ "array_from_numpy",
102
+ "bytes_to_ndarray",
103
+ "configure",
104
+ "event",
105
+ "log",
106
+ "ndarray_to_bytes",
107
+ "ndarrays_to_parameters",
108
+ "now",
109
+ "parameters_to_ndarrays",
110
110
  ]
flwr/common/config.py ADDED
@@ -0,0 +1,71 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Provide functions for managing global Flower config."""
16
+
17
+ import os
18
+ from pathlib import Path
19
+ from typing import Any, Dict, Optional, Union
20
+
21
+ import tomli
22
+
23
+ from flwr.cli.config_utils import validate_fields
24
+ from flwr.common.constant import APP_DIR, FAB_CONFIG_FILE, FLWR_HOME
25
+
26
+
27
def get_flwr_dir() -> Path:
    """Return the Flower home directory based on env variables.

    Resolution order:

    1. The value of the environment variable named by ``FLWR_HOME``, if set.
    2. ``$XDG_DATA_HOME/.flwr``, if ``XDG_DATA_HOME`` is set and non-empty.
    3. ``$HOME/.flwr``, if ``HOME`` is set and non-empty.
    4. ``Path.home() / ".flwr"`` otherwise (e.g. on platforms where ``HOME``
       is not set).

    Returns
    -------
    Path
        The Flower home directory (it is not created or checked for existence).
    """
    flwr_home = os.getenv(FLWR_HOME)
    if flwr_home is not None:
        return Path(flwr_home)

    # Treat an empty XDG_DATA_HOME/HOME like an unset one. Without the
    # `Path.home()` fallback, an unset HOME would produce the path "None/.flwr".
    base_dir = os.getenv("XDG_DATA_HOME") or os.getenv("HOME")
    if not base_dir:
        return Path.home() / ".flwr"
    return Path(base_dir) / ".flwr"
35
+
36
+
37
def get_project_dir(
    fab_id: str, fab_version: str, flwr_dir: Optional[Union[str, Path]] = None
) -> Path:
    """Return the project directory based on the given fab_id and fab_version.

    The returned path is
    ``<flwr_dir>/<APP_DIR>/<publisher>/<project_name>/<fab_version>``, where
    ``fab_id`` has the form ``"<publisher>/<project_name>"``.

    Parameters
    ----------
    fab_id : str
        FAB identifier of the form ``"<publisher>/<project_name>"``.
    fab_version : str
        FAB version, used as the last path component (may be empty).
    flwr_dir : Optional[Union[str, Path]] (default: None)
        Flower home directory. If ``None``, ``get_flwr_dir()`` is used.

    Returns
    -------
    Path
        The project directory (not checked for existence).

    Raises
    ------
    ValueError
        If ``fab_id`` does not consist of exactly two non-empty components, or
        if any component (including a non-empty ``fab_version``) is ``.``,
        ``..`` or contains a path separator.
    """
    # Check the fab_id
    if fab_id.count("/") != 1:
        raise ValueError(
            f"Invalid FAB ID: {fab_id}",
        )
    publisher, project_name = fab_id.split("/")

    # NOTE(review): fab_id/fab_version presumably come from the Run info sent
    # by the SuperLink; treat them as untrusted so the resulting path (whose
    # code is later imported) cannot escape `<flwr_dir>/<APP_DIR>`.
    if not (
        _is_safe_path_component(publisher) and _is_safe_path_component(project_name)
    ):
        raise ValueError(
            f"Invalid FAB ID: {fab_id}",
        )
    # An empty version keeps the original behaviour (no extra path component).
    if fab_version and not _is_safe_path_component(fab_version):
        raise ValueError(
            f"Invalid FAB version: {fab_version}",
        )

    if flwr_dir is None:
        flwr_dir = get_flwr_dir()
    return Path(flwr_dir) / APP_DIR / publisher / project_name / fab_version


def _is_safe_path_component(component: str) -> bool:
    """Return True if `component` is one non-empty, non-relative path part."""
    return (
        component not in ("", ".", "..")
        and "/" not in component
        and "\\" not in component
    )
50
+
51
+
52
def get_project_config(project_dir: Union[str, Path]) -> Dict[str, Any]:
    """Load and validate the ``FAB_CONFIG_FILE`` of a project directory.

    Parameters
    ----------
    project_dir : Union[str, Path]
        Directory expected to contain ``FAB_CONFIG_FILE``.

    Returns
    -------
    Dict[str, Any]
        The parsed TOML document.

    Raises
    ------
    FileNotFoundError
        If ``FAB_CONFIG_FILE`` is not a regular file inside ``project_dir``.
    ValueError
        If ``validate_fields`` reports the parsed configuration as invalid.
    """
    config_path = Path(project_dir).joinpath(FAB_CONFIG_FILE)
    if not config_path.is_file():
        raise FileNotFoundError(
            f"Cannot find {FAB_CONFIG_FILE} in {project_dir}",
        )

    config = tomli.loads(config_path.read_text(encoding="utf-8"))

    # Reject configurations that fail field validation, listing every problem.
    valid, problems, _ = validate_fields(config)
    if not valid:
        bullet_list = "\n".join(f" - {problem}" for problem in problems)
        raise ValueError(
            f"Invalid {FAB_CONFIG_FILE}:\n{bullet_list}",
        )

    return config
flwr/common/constant.py CHANGED
@@ -27,6 +27,7 @@ To use the REST API, install `flwr` with the `rest` extra:
27
27
 
28
28
  TRANSPORT_TYPE_GRPC_BIDI = "grpc-bidi"
29
29
  TRANSPORT_TYPE_GRPC_RERE = "grpc-rere"
30
+ TRANSPORT_TYPE_GRPC_ADAPTER = "grpc-adapter"
30
31
  TRANSPORT_TYPE_REST = "rest"
31
32
  TRANSPORT_TYPE_VCE = "vce"
32
33
  TRANSPORT_TYPES = [
@@ -36,6 +37,8 @@ TRANSPORT_TYPES = [
36
37
  TRANSPORT_TYPE_VCE,
37
38
  ]
38
39
 
40
+ SUPEREXEC_DEFAULT_ADDRESS = "0.0.0.0:9093"
41
+
39
42
  # Constants for ping
40
43
  PING_DEFAULT_INTERVAL = 30
41
44
  PING_CALL_TIMEOUT = 5
@@ -43,6 +46,18 @@ PING_BASE_MULTIPLIER = 0.8
43
46
  PING_RANDOM_RANGE = (-0.1, 0.1)
44
47
  PING_MAX_INTERVAL = 1e300
45
48
 
49
# Keys used in `MessageContainer.metadata` by the gRPC adapter transport.
# (Previously defined twice in this module; a single definition is kept.)
GRPC_ADAPTER_METADATA_FLOWER_VERSION_KEY = "flower-version"
GRPC_ADAPTER_METADATA_SHOULD_EXIT_KEY = "should-exit"

# Constants for FAB
# Sub-directory of the Flower home directory that holds installed apps.
APP_DIR = "apps"
# Project configuration file read from each app directory.
FAB_CONFIG_FILE = "pyproject.toml"
# Name of the environment variable that overrides the Flower home directory.
FLWR_HOME = "FLWR_HOME"
60
+
46
61
 
47
62
  class MessageType:
48
63
  """Message type."""