flwr-nightly 1.10.0.dev20240612__py3-none-any.whl → 1.10.0.dev20240619__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

Files changed (53) hide show
  1. flwr/cli/build.py +3 -1
  2. flwr/cli/config_utils.py +53 -3
  3. flwr/cli/install.py +35 -20
  4. flwr/cli/run/run.py +39 -2
  5. flwr/client/__init__.py +1 -1
  6. flwr/client/app.py +22 -10
  7. flwr/client/grpc_adapter_client/__init__.py +15 -0
  8. flwr/client/grpc_adapter_client/connection.py +94 -0
  9. flwr/client/grpc_client/connection.py +5 -1
  10. flwr/client/grpc_rere_client/connection.py +8 -1
  11. flwr/client/grpc_rere_client/grpc_adapter.py +133 -0
  12. flwr/client/mod/__init__.py +3 -3
  13. flwr/client/rest_client/connection.py +9 -1
  14. flwr/client/supernode/app.py +140 -40
  15. flwr/common/__init__.py +12 -12
  16. flwr/common/config.py +71 -0
  17. flwr/common/constant.py +15 -0
  18. flwr/common/object_ref.py +39 -5
  19. flwr/common/record/__init__.py +1 -1
  20. flwr/common/telemetry.py +4 -0
  21. flwr/common/typing.py +9 -0
  22. flwr/proto/exec_pb2.py +34 -0
  23. flwr/proto/exec_pb2.pyi +55 -0
  24. flwr/proto/exec_pb2_grpc.py +101 -0
  25. flwr/proto/exec_pb2_grpc.pyi +41 -0
  26. flwr/proto/fab_pb2.py +30 -0
  27. flwr/proto/fab_pb2.pyi +56 -0
  28. flwr/proto/fab_pb2_grpc.py +4 -0
  29. flwr/proto/fab_pb2_grpc.pyi +4 -0
  30. flwr/server/__init__.py +2 -2
  31. flwr/server/app.py +62 -25
  32. flwr/server/run_serverapp.py +4 -2
  33. flwr/server/strategy/__init__.py +2 -2
  34. flwr/server/superlink/fleet/grpc_adapter/__init__.py +15 -0
  35. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +131 -0
  36. flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +4 -0
  37. flwr/server/superlink/fleet/message_handler/message_handler.py +3 -3
  38. flwr/server/superlink/fleet/vce/vce_api.py +3 -1
  39. flwr/server/superlink/state/in_memory_state.py +8 -5
  40. flwr/server/superlink/state/sqlite_state.py +6 -3
  41. flwr/server/superlink/state/state.py +5 -4
  42. flwr/simulation/__init__.py +4 -1
  43. flwr/simulation/run_simulation.py +22 -0
  44. flwr/superexec/__init__.py +21 -0
  45. flwr/superexec/app.py +178 -0
  46. flwr/superexec/exec_grpc.py +51 -0
  47. flwr/superexec/exec_servicer.py +65 -0
  48. flwr/superexec/executor.py +54 -0
  49. {flwr_nightly-1.10.0.dev20240612.dist-info → flwr_nightly-1.10.0.dev20240619.dist-info}/METADATA +1 -1
  50. {flwr_nightly-1.10.0.dev20240612.dist-info → flwr_nightly-1.10.0.dev20240619.dist-info}/RECORD +53 -34
  51. {flwr_nightly-1.10.0.dev20240612.dist-info → flwr_nightly-1.10.0.dev20240619.dist-info}/entry_points.txt +1 -0
  52. {flwr_nightly-1.10.0.dev20240612.dist-info → flwr_nightly-1.10.0.dev20240619.dist-info}/LICENSE +0 -0
  53. {flwr_nightly-1.10.0.dev20240612.dist-info → flwr_nightly-1.10.0.dev20240619.dist-info}/WHEEL +0 -0
@@ -0,0 +1,131 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Fleet API gRPC adapter servicer."""
16
+
17
+
18
+ from logging import DEBUG, INFO
19
+ from typing import Callable, Type, TypeVar
20
+
21
+ import grpc
22
+ from google.protobuf.message import Message as GrpcMessage
23
+
24
+ from flwr.common.logger import log
25
+ from flwr.proto import grpcadapter_pb2_grpc # pylint: disable=E0611
26
+ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
27
+ CreateNodeRequest,
28
+ CreateNodeResponse,
29
+ DeleteNodeRequest,
30
+ DeleteNodeResponse,
31
+ PingRequest,
32
+ PingResponse,
33
+ PullTaskInsRequest,
34
+ PullTaskInsResponse,
35
+ PushTaskResRequest,
36
+ PushTaskResResponse,
37
+ )
38
+ from flwr.proto.grpcadapter_pb2 import MessageContainer # pylint: disable=E0611
39
+ from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
40
+ from flwr.server.superlink.fleet.message_handler import message_handler
41
+ from flwr.server.superlink.state import StateFactory
42
+
43
+ T = TypeVar("T", bound=GrpcMessage)
44
+
45
+
46
def _handle(
    msg_container: MessageContainer,
    request_type: Type[T],
    handler: Callable[[T], GrpcMessage],
) -> MessageContainer:
    """Unwrap a request from its container, process it, and wrap the response.

    Parameters
    ----------
    msg_container : MessageContainer
        Container whose `grpc_message_content` holds the serialized request.
    request_type : Type[T]
        Protobuf message class used to deserialize the request bytes.
    handler : Callable[[T], GrpcMessage]
        Function that turns the deserialized request into a response message.

    Returns
    -------
    MessageContainer
        Container with empty metadata that carries the response's class name
        and its serialized bytes.
    """
    request = request_type.FromString(msg_container.grpc_message_content)
    response = handler(request)
    response_name = type(response).__qualname__
    response_bytes = response.SerializeToString()
    return MessageContainer(
        metadata={},
        grpc_message_name=response_name,
        grpc_message_content=response_bytes,
    )
58
+
59
+
60
class GrpcAdapterServicer(grpcadapter_pb2_grpc.GrpcAdapterServicer):
    """Fleet API servicer that receives Fleet requests inside `MessageContainer`."""

    def __init__(self, state_factory: StateFactory) -> None:
        # Factory used to obtain the state object handed to every handler
        self.state_factory = state_factory

    def SendReceive(
        self, request: MessageContainer, context: grpc.ServicerContext
    ) -> MessageContainer:
        """Route a wrapped request to the handler matching its message name.

        Parameters
        ----------
        request : MessageContainer
            Container whose `grpc_message_name` selects the handler and whose
            `grpc_message_content` holds the serialized request.
        context : grpc.ServicerContext
            gRPC call context (not used by this method).

        Returns
        -------
        MessageContainer
            Container wrapping the serialized response of the selected handler.

        Raises
        ------
        ValueError
            If `grpc_message_name` matches none of the supported request types.
        """
        log(DEBUG, "GrpcAdapterServicer.SendReceive")
        # Supported request types, checked in this order
        routes = (
            (CreateNodeRequest, self._create_node),
            (DeleteNodeRequest, self._delete_node),
            (PingRequest, self._ping),
            (PullTaskInsRequest, self._pull_task_ins),
            (PushTaskResRequest, self._push_task_res),
            (GetRunRequest, self._get_run),
        )
        message_name = request.grpc_message_name
        for request_type, handler in routes:
            if message_name == request_type.__qualname__:
                return _handle(request, request_type, handler)
        raise ValueError(f"Invalid grpc_message_name: {request.grpc_message_name}")

    def _create_node(self, request: CreateNodeRequest) -> CreateNodeResponse:
        """Forward a `CreateNodeRequest` to the message handler."""
        log(INFO, "GrpcAdapter.CreateNode")
        state = self.state_factory.state()
        return message_handler.create_node(request=request, state=state)

    def _delete_node(self, request: DeleteNodeRequest) -> DeleteNodeResponse:
        """Forward a `DeleteNodeRequest` to the message handler."""
        log(INFO, "GrpcAdapter.DeleteNode")
        state = self.state_factory.state()
        return message_handler.delete_node(request=request, state=state)

    def _ping(self, request: PingRequest) -> PingResponse:
        """Forward a `PingRequest` to the message handler."""
        log(DEBUG, "GrpcAdapter.Ping")
        state = self.state_factory.state()
        return message_handler.ping(request=request, state=state)

    def _pull_task_ins(self, request: PullTaskInsRequest) -> PullTaskInsResponse:
        """Pull TaskIns."""
        log(INFO, "GrpcAdapter.PullTaskIns")
        state = self.state_factory.state()
        return message_handler.pull_task_ins(request=request, state=state)

    def _push_task_res(self, request: PushTaskResRequest) -> PushTaskResResponse:
        """Push TaskRes."""
        log(INFO, "GrpcAdapter.PushTaskRes")
        state = self.state_factory.state()
        return message_handler.push_task_res(request=request, state=state)

    def _get_run(self, request: GetRunRequest) -> GetRunResponse:
        """Get run information."""
        log(INFO, "GrpcAdapter.GetRun")
        state = self.state_factory.state()
        return message_handler.get_run(request=request, state=state)
@@ -29,6 +29,9 @@ from flwr.proto.transport_pb2_grpc import ( # pylint: disable=E0611
29
29
  )
30
30
  from flwr.server.client_manager import ClientManager
31
31
  from flwr.server.superlink.driver.driver_servicer import DriverServicer
32
+ from flwr.server.superlink.fleet.grpc_adapter.grpc_adapter_servicer import (
33
+ GrpcAdapterServicer,
34
+ )
32
35
  from flwr.server.superlink.fleet.grpc_bidi.flower_service_servicer import (
33
36
  FlowerServiceServicer,
34
37
  )
@@ -154,6 +157,7 @@ def start_grpc_server( # pylint: disable=too-many-arguments
154
157
  def generic_create_grpc_server( # pylint: disable=too-many-arguments
155
158
  servicer_and_add_fn: Union[
156
159
  Tuple[FleetServicer, AddServicerToServerFn],
160
+ Tuple[GrpcAdapterServicer, AddServicerToServerFn],
157
161
  Tuple[FlowerServiceServicer, AddServicerToServerFn],
158
162
  Tuple[DriverServicer, AddServicerToServerFn],
159
163
  ],
@@ -112,6 +112,6 @@ def get_run(
112
112
  request: GetRunRequest, state: State # pylint: disable=W0613
113
113
  ) -> GetRunResponse:
114
114
  """Get run information."""
115
- run_id, fab_id, fab_version = state.get_run(request.run_id)
116
- run = Run(run_id=run_id, fab_id=fab_id, fab_version=fab_version)
117
- return GetRunResponse(run=run)
115
+ run = state.get_run(request.run_id)
116
+ run_proto = None if run is None else Run(**vars(run))
117
+ return GetRunResponse(run=run_proto)
@@ -20,6 +20,7 @@ import sys
20
20
  import time
21
21
  import traceback
22
22
  from logging import DEBUG, ERROR, INFO, WARN
23
+ from pathlib import Path
23
24
  from typing import Callable, Dict, List, Optional
24
25
 
25
26
  from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
@@ -274,6 +275,7 @@ def start_vce(
274
275
  # Use mapping constructed externally. This also means nodes
275
276
  # have previously being registered.
276
277
  nodes_mapping = existing_nodes_mapping
278
+ app_dir = str(Path(app_dir).absolute())
277
279
 
278
280
  if not state_factory:
279
281
  log(INFO, "A StateFactory was not supplied to the SimulationEngine.")
@@ -323,7 +325,7 @@ def start_vce(
323
325
  if app_dir is not None:
324
326
  sys.path.insert(0, app_dir)
325
327
 
326
- app: ClientApp = load_app(client_app_attr, LoadClientAppError)
328
+ app: ClientApp = load_app(client_app_attr, LoadClientAppError, app_dir)
327
329
 
328
330
  if not isinstance(app, ClientApp):
329
331
  raise LoadClientAppError(
@@ -23,6 +23,7 @@ from typing import Dict, List, Optional, Set, Tuple
23
23
  from uuid import UUID, uuid4
24
24
 
25
25
  from flwr.common import log, now
26
+ from flwr.common.typing import Run
26
27
  from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
27
28
  from flwr.server.superlink.state.state import State
28
29
  from flwr.server.utils import validate_task_ins_or_res
@@ -40,7 +41,7 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
40
41
  self.public_key_to_node_id: Dict[bytes, int] = {}
41
42
 
42
43
  # Map run_id to its Run record (run_id, fab_id, fab_version)
43
- self.run_ids: Dict[int, Tuple[str, str]] = {}
44
+ self.run_ids: Dict[int, Run] = {}
44
45
  self.task_ins_store: Dict[UUID, TaskIns] = {}
45
46
  self.task_res_store: Dict[UUID, TaskRes] = {}
46
47
 
@@ -281,7 +282,9 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
281
282
  run_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
282
283
 
283
284
  if run_id not in self.run_ids:
284
- self.run_ids[run_id] = (fab_id, fab_version)
285
+ self.run_ids[run_id] = Run(
286
+ run_id=run_id, fab_id=fab_id, fab_version=fab_version
287
+ )
285
288
  return run_id
286
289
  log(ERROR, "Unexpected run creation failure.")
287
290
  return 0
@@ -319,13 +322,13 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
319
322
  """Retrieve all currently stored `client_public_keys` as a set."""
320
323
  return self.client_public_keys
321
324
 
322
- def get_run(self, run_id: int) -> Tuple[int, str, str]:
325
+ def get_run(self, run_id: int) -> Optional[Run]:
323
326
  """Retrieve information about the run with the specified `run_id`."""
324
327
  with self.lock:
325
328
  if run_id not in self.run_ids:
326
329
  log(ERROR, "`run_id` is invalid")
327
- return 0, "", ""
328
- return run_id, *self.run_ids[run_id]
330
+ return None
331
+ return self.run_ids[run_id]
329
332
 
330
333
  def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
331
334
  """Acknowledge a ping received from a node, serving as a heartbeat."""
@@ -24,6 +24,7 @@ from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union, cast
24
24
  from uuid import UUID, uuid4
25
25
 
26
26
  from flwr.common import log, now
27
+ from flwr.common.typing import Run
27
28
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
28
29
  from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611
29
30
  from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
@@ -680,15 +681,17 @@ class SqliteState(State): # pylint: disable=R0904
680
681
  result: Set[bytes] = {row["public_key"] for row in rows}
681
682
  return result
682
683
 
683
- def get_run(self, run_id: int) -> Tuple[int, str, str]:
684
+ def get_run(self, run_id: int) -> Optional[Run]:
684
685
  """Retrieve information about the run with the specified `run_id`."""
685
686
  query = "SELECT * FROM run WHERE run_id = ?;"
686
687
  try:
687
688
  row = self.query(query, (run_id,))[0]
688
- return run_id, row["fab_id"], row["fab_version"]
689
+ return Run(
690
+ run_id=run_id, fab_id=row["fab_id"], fab_version=row["fab_version"]
691
+ )
689
692
  except sqlite3.IntegrityError:
690
693
  log(ERROR, "`run_id` does not exist.")
691
- return 0, "", ""
694
+ return None
692
695
 
693
696
  def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
694
697
  """Acknowledge a ping received from a node, serving as a heartbeat."""
@@ -16,9 +16,10 @@
16
16
 
17
17
 
18
18
  import abc
19
- from typing import List, Optional, Set, Tuple
19
+ from typing import List, Optional, Set
20
20
  from uuid import UUID
21
21
 
22
+ from flwr.common.typing import Run
22
23
  from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
23
24
 
24
25
 
@@ -160,7 +161,7 @@ class State(abc.ABC): # pylint: disable=R0904
160
161
  """Create a new run for the specified `fab_id` and `fab_version`."""
161
162
 
162
163
  @abc.abstractmethod
163
- def get_run(self, run_id: int) -> Tuple[int, str, str]:
164
+ def get_run(self, run_id: int) -> Optional[Run]:
164
165
  """Retrieve information about the run with the specified `run_id`.
165
166
 
166
167
  Parameters
@@ -170,8 +171,8 @@ class State(abc.ABC): # pylint: disable=R0904
170
171
 
171
172
  Returns
172
173
  -------
173
- Tuple[int, str, str]
174
- A tuple containing three elements:
174
+ Optional[Run]
175
+ A dataclass instance containing three elements if `run_id` is valid:
175
176
  - `run_id`: The identifier of the run, same as the specified `run_id`.
176
177
  - `fab_id`: The identifier of the FAB used in the specified run.
177
178
  - `fab_version`: The version of the FAB used in the specified run.
@@ -36,4 +36,7 @@ To install the necessary dependencies, install `flwr` with the `simulation` extr
36
36
  raise ImportError(RAY_IMPORT_ERROR)
37
37
 
38
38
 
39
- __all__ = ["start_simulation", "run_simulation"]
39
+ __all__ = [
40
+ "run_simulation",
41
+ "start_simulation",
42
+ ]
@@ -53,6 +53,7 @@ def run_simulation_from_cli() -> None:
53
53
  backend_name=args.backend,
54
54
  backend_config=backend_config_dict,
55
55
  app_dir=args.app_dir,
56
+ run_id=args.run_id,
56
57
  enable_tf_gpu_growth=args.enable_tf_gpu_growth,
57
58
  verbose_logging=args.verbose,
58
59
  )
@@ -168,6 +169,13 @@ def run_serverapp_th(
168
169
  return serverapp_th
169
170
 
170
171
 
172
def _init_run_id(driver: InMemoryDriver, state: StateFactory, run_id: int) -> None:
    """Create a run with a given `run_id`.

    Parameters
    ----------
    driver : InMemoryDriver
        Driver whose `run_id` attribute is set to the pre-registered run.
    state : StateFactory
        Factory providing the state object in which the run is registered.
    run_id : int
        Identifier to register.
    """
    # Local import keeps this helper self-contained
    from flwr.common.typing import Run  # pylint: disable=import-outside-toplevel

    log(DEBUG, "Pre-registering run with id %s", run_id)
    # `run_ids` holds `Run` records, so store one (with empty FAB fields) rather
    # than a bare tuple; otherwise `get_run` would return a non-`Run` value.
    state.state().run_ids[run_id] = Run(  # type: ignore
        run_id=run_id, fab_id="", fab_version=""
    )
    driver.run_id = run_id
177
+
178
+
171
179
  # pylint: disable=too-many-locals
172
180
  def _main_loop(
173
181
  num_supernodes: int,
@@ -175,6 +183,7 @@ def _main_loop(
175
183
  backend_config_stream: str,
176
184
  app_dir: str,
177
185
  enable_tf_gpu_growth: bool,
186
+ run_id: Optional[int] = None,
178
187
  client_app: Optional[ClientApp] = None,
179
188
  client_app_attr: Optional[str] = None,
180
189
  server_app: Optional[ServerApp] = None,
@@ -195,6 +204,9 @@ def _main_loop(
195
204
  # Initialize Driver
196
205
  driver = InMemoryDriver(state_factory)
197
206
 
207
+ if run_id:
208
+ _init_run_id(driver, state_factory, run_id)
209
+
198
210
  # Get and run ServerApp thread
199
211
  serverapp_th = run_serverapp_th(
200
212
  server_app_attr=server_app_attr,
@@ -244,6 +256,7 @@ def _run_simulation(
244
256
  client_app_attr: Optional[str] = None,
245
257
  server_app_attr: Optional[str] = None,
246
258
  app_dir: str = "",
259
+ run_id: Optional[int] = None,
247
260
  enable_tf_gpu_growth: bool = False,
248
261
  verbose_logging: bool = False,
249
262
  ) -> None:
@@ -283,6 +296,9 @@ def _run_simulation(
283
296
  Add specified directory to the PYTHONPATH and load `ClientApp` from there.
284
297
  (Default: current working directory.)
285
298
 
299
+ run_id : Optional[int]
300
+ An integer specifying the ID of the run started when running this function.
301
+
286
302
  enable_tf_gpu_growth : bool (default: False)
287
303
  A boolean to indicate whether to enable GPU growth on the main thread. This is
288
304
  desirable if you make use of a TensorFlow model on your `ServerApp` while
@@ -322,6 +338,7 @@ def _run_simulation(
322
338
  backend_config_stream,
323
339
  app_dir,
324
340
  enable_tf_gpu_growth,
341
+ run_id,
325
342
  client_app,
326
343
  client_app_attr,
327
344
  server_app,
@@ -413,5 +430,10 @@ def _parse_args_run_simulation() -> argparse.ArgumentParser:
413
430
  "ClientApp and ServerApp from there."
414
431
  " Default: current working directory.",
415
432
  )
433
+ parser.add_argument(
434
+ "--run-id",
435
+ type=int,
436
+ help="Sets the ID of the run started by the Simulation Engine.",
437
+ )
416
438
 
417
439
  return parser
@@ -0,0 +1,21 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Flower SuperExec service."""
16
+
17
+ from .app import run_superexec as run_superexec
18
+
19
+ __all__ = [
20
+ "run_superexec",
21
+ ]
flwr/superexec/app.py ADDED
@@ -0,0 +1,178 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Flower SuperExec app."""
16
+
17
+ import argparse
18
+ import sys
19
+ from logging import INFO, WARN
20
+ from pathlib import Path
21
+ from typing import Optional, Tuple
22
+
23
+ import grpc
24
+
25
+ from flwr.common import EventType, event, log
26
+ from flwr.common.address import parse_address
27
+ from flwr.common.constant import SUPEREXEC_DEFAULT_ADDRESS
28
+ from flwr.common.exit_handlers import register_exit_handlers
29
+ from flwr.common.object_ref import load_app, validate
30
+
31
+ from .exec_grpc import run_superexec_api_grpc
32
+ from .executor import Executor
33
+
34
+
35
def run_superexec() -> None:
    """Run Flower SuperExec.

    Parses the command line, starts the SuperExec gRPC API on the requested
    address, registers exit handlers, and blocks until the server terminates.
    Exits the process if the address cannot be parsed.
    """
    log(INFO, "Starting Flower SuperExec")

    event(EventType.RUN_SUPEREXEC_ENTER)

    cli_args = _parse_args_run_superexec().parse_args()

    # Validate the address and bring it into `host:port` form
    addr_parts = parse_address(cli_args.address)
    if not addr_parts:
        sys.exit(f"SuperExec IP address ({cli_args.address}) cannot be parsed.")
    host, port, is_v6 = addr_parts
    if is_v6:
        bind_address = f"[{host}]:{port}"
    else:
        bind_address = f"{host}:{port}"

    # Certificates are resolved before the executor plugin is loaded
    tls_material = _try_obtain_certificates(cli_args)
    executor = _load_executor(cli_args)

    # Start SuperExec API
    server: grpc.Server = run_superexec_api_grpc(
        address=bind_address,
        executor=executor,
        certificates=tls_material,
    )

    # Graceful shutdown
    register_exit_handlers(
        event_type=EventType.RUN_SUPEREXEC_LEAVE,
        grpc_servers=[server],
        bckg_threads=None,
    )

    server.wait_for_termination()
70
+
71
+
72
def _parse_args_run_superexec() -> argparse.ArgumentParser:
    """Build the command line parser for SuperExec.

    Returns
    -------
    argparse.ArgumentParser
        Parser with the positional `executor` reference and the `--address`,
        `--executor-dir`, `--insecure` and `--ssl-*` options.
    """
    arg_parser = argparse.ArgumentParser(description="Start a Flower SuperExec")
    arg_parser.add_argument(
        "executor",
        help="For example: `deployment:exec` or `project.package.module:wrapper.exec`.",
    )
    arg_parser.add_argument(
        "--address",
        default=SUPEREXEC_DEFAULT_ADDRESS,
        help="SuperExec (gRPC) server address (IPv4, IPv6, or a domain name)",
    )
    arg_parser.add_argument(
        "--executor-dir",
        default=".",
        help="The directory for the executor.",
    )
    arg_parser.add_argument(
        "--insecure",
        action="store_true",
        help="Run the SuperExec without HTTPS, regardless of whether certificate "
        "paths are provided. By default, the server runs with HTTPS enabled. "
        "Use this flag only if you understand the risks.",
    )
    # The three SSL options differ only in the flag and the kind of file named
    ssl_options = (
        ("--ssl-certfile", "certificate"),
        ("--ssl-keyfile", "private key"),
        ("--ssl-ca-certfile", "CA certificate"),
    )
    for flag, file_kind in ssl_options:
        arg_parser.add_argument(
            flag,
            type=str,
            default=None,
            help=f"SuperExec server SSL {file_kind} file (as a path str) "
            "to create a secure connection.",
        )
    return arg_parser
118
+
119
+
120
def _try_obtain_certificates(
    args: argparse.Namespace,
) -> Optional[Tuple[bytes, bytes, bytes]]:
    """Load the TLS material requested on the command line.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed arguments providing `insecure`, `ssl_ca_certfile`,
        `ssl_certfile` and `ssl_keyfile`.

    Returns
    -------
    Optional[Tuple[bytes, bytes, bytes]]
        `None` if `--insecure` was set; otherwise the contents of the CA
        certificate, the server certificate and the server private key,
        in that order.

    Raises
    ------
    SystemExit
        If a given path does not point to a file, if only some of the three
        paths were given, or if none were given without `--insecure`.
    """
    if args.insecure:
        log(WARN, "Option `--insecure` was set. Starting insecure HTTP server.")
        return None

    # Validated in this order: CA certificate, server certificate, private key
    cert_options = (
        ("--ssl-ca-certfile", args.ssl_ca_certfile),
        ("--ssl-certfile", args.ssl_certfile),
        ("--ssl-keyfile", args.ssl_keyfile),
    )
    provided = [value for _, value in cert_options if value]

    if len(provided) == len(cert_options):
        for flag, value in cert_options:
            # The arguments are plain strings, so wrap them in `Path` before
            # calling `is_file` (the unbound `Path.is_file(str)` call is not
            # valid for a `str` argument).
            if not Path(value).is_file():
                sys.exit(f"Path argument `{flag}` does not point to a file.")
        return (
            Path(args.ssl_ca_certfile).read_bytes(),  # CA certificate
            Path(args.ssl_certfile).read_bytes(),  # server certificate
            Path(args.ssl_keyfile).read_bytes(),  # server private key
        )

    if provided:
        # Only some of the three paths were supplied
        sys.exit(
            "You need to provide valid file paths to `--ssl-certfile`, "
            "`--ssl-keyfile`, and `--ssl-ca-certfile` to create a secure "
            "connection in SuperExec server (gRPC-rere)."
        )
    sys.exit(
        "Certificates are required unless running in insecure mode. "
        "Please provide certificate paths to `--ssl-certfile`, "
        "`--ssl-keyfile`, and `--ssl-ca-certfile` or run the server "
        "in insecure mode using '--insecure' if you understand the risks."
    )
153
+
154
+
155
def _load_executor(
    args: argparse.Namespace,
) -> Executor:
    """Get the executor plugin.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed arguments providing the `executor` object reference and the
        `executor_dir` directory.

    Returns
    -------
    Executor
        The object the reference resolves to.

    Raises
    ------
    LoadExecutorError
        If the reference fails validation with an error message, or if the
        loaded object is not an `Executor` instance.
    """
    executor_dir = args.executor_dir
    if executor_dir is not None:
        # Put the executor directory first on the module search path
        sys.path.insert(0, executor_dir)

    ref: str = args.executor
    is_valid, reason = validate(ref)
    if reason and not is_valid:
        raise LoadExecutorError(reason) from None

    loaded = load_app(ref, LoadExecutorError, executor_dir)
    if isinstance(loaded, Executor):
        return loaded

    raise LoadExecutorError(
        f"Attribute {ref} is not of type {Executor}",
    ) from None
175
+
176
+
177
class LoadExecutorError(Exception):
    """Raised when the `Executor` plugin cannot be loaded."""
@@ -0,0 +1,51 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """SuperExec gRPC API."""
16
+
17
+ from logging import INFO
18
+ from typing import Optional, Tuple
19
+
20
+ import grpc
21
+
22
+ from flwr.common import GRPC_MAX_MESSAGE_LENGTH
23
+ from flwr.common.logger import log
24
+ from flwr.proto.exec_pb2_grpc import add_ExecServicer_to_server
25
+ from flwr.server.superlink.fleet.grpc_bidi.grpc_server import generic_create_grpc_server
26
+
27
+ from .exec_servicer import ExecServicer
28
+ from .executor import Executor
29
+
30
+
31
def run_superexec_api_grpc(
    address: str,
    executor: Executor,
    certificates: Optional[Tuple[bytes, bytes, bytes]],
) -> grpc.Server:
    """Run SuperExec API (gRPC, request-response).

    Parameters
    ----------
    address : str
        Address passed to the gRPC server as `server_address`.
    executor : Executor
        Executor plugin handed to the `ExecServicer`.
    certificates : Optional[Tuple[bytes, bytes, bytes]]
        Certificates forwarded unchanged to the gRPC server factory, or
        `None`.

    Returns
    -------
    grpc.Server
        The already started gRPC server.
    """
    # The servicer is an `ExecServicer`, not a `grpc.Server`
    exec_servicer: ExecServicer = ExecServicer(
        executor=executor,
    )
    superexec_add_servicer_to_server_fn = add_ExecServicer_to_server
    superexec_grpc_server = generic_create_grpc_server(
        servicer_and_add_fn=(exec_servicer, superexec_add_servicer_to_server_fn),
        server_address=address,
        max_message_length=GRPC_MAX_MESSAGE_LENGTH,
        certificates=certificates,
    )

    log(INFO, "Flower ECE: Starting SuperExec API (gRPC-rere) on %s", address)
    superexec_grpc_server.start()

    return superexec_grpc_server