flwr-nightly 1.10.0.dev20240612__py3-none-any.whl → 1.10.0.dev20240619__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/build.py +3 -1
- flwr/cli/config_utils.py +53 -3
- flwr/cli/install.py +35 -20
- flwr/cli/run/run.py +39 -2
- flwr/client/__init__.py +1 -1
- flwr/client/app.py +22 -10
- flwr/client/grpc_adapter_client/__init__.py +15 -0
- flwr/client/grpc_adapter_client/connection.py +94 -0
- flwr/client/grpc_client/connection.py +5 -1
- flwr/client/grpc_rere_client/connection.py +8 -1
- flwr/client/grpc_rere_client/grpc_adapter.py +133 -0
- flwr/client/mod/__init__.py +3 -3
- flwr/client/rest_client/connection.py +9 -1
- flwr/client/supernode/app.py +140 -40
- flwr/common/__init__.py +12 -12
- flwr/common/config.py +71 -0
- flwr/common/constant.py +15 -0
- flwr/common/object_ref.py +39 -5
- flwr/common/record/__init__.py +1 -1
- flwr/common/telemetry.py +4 -0
- flwr/common/typing.py +9 -0
- flwr/proto/exec_pb2.py +34 -0
- flwr/proto/exec_pb2.pyi +55 -0
- flwr/proto/exec_pb2_grpc.py +101 -0
- flwr/proto/exec_pb2_grpc.pyi +41 -0
- flwr/proto/fab_pb2.py +30 -0
- flwr/proto/fab_pb2.pyi +56 -0
- flwr/proto/fab_pb2_grpc.py +4 -0
- flwr/proto/fab_pb2_grpc.pyi +4 -0
- flwr/server/__init__.py +2 -2
- flwr/server/app.py +62 -25
- flwr/server/run_serverapp.py +4 -2
- flwr/server/strategy/__init__.py +2 -2
- flwr/server/superlink/fleet/grpc_adapter/__init__.py +15 -0
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +131 -0
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +4 -0
- flwr/server/superlink/fleet/message_handler/message_handler.py +3 -3
- flwr/server/superlink/fleet/vce/vce_api.py +3 -1
- flwr/server/superlink/state/in_memory_state.py +8 -5
- flwr/server/superlink/state/sqlite_state.py +6 -3
- flwr/server/superlink/state/state.py +5 -4
- flwr/simulation/__init__.py +4 -1
- flwr/simulation/run_simulation.py +22 -0
- flwr/superexec/__init__.py +21 -0
- flwr/superexec/app.py +178 -0
- flwr/superexec/exec_grpc.py +51 -0
- flwr/superexec/exec_servicer.py +65 -0
- flwr/superexec/executor.py +54 -0
- {flwr_nightly-1.10.0.dev20240612.dist-info → flwr_nightly-1.10.0.dev20240619.dist-info}/METADATA +1 -1
- {flwr_nightly-1.10.0.dev20240612.dist-info → flwr_nightly-1.10.0.dev20240619.dist-info}/RECORD +53 -34
- {flwr_nightly-1.10.0.dev20240612.dist-info → flwr_nightly-1.10.0.dev20240619.dist-info}/entry_points.txt +1 -0
- {flwr_nightly-1.10.0.dev20240612.dist-info → flwr_nightly-1.10.0.dev20240619.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.10.0.dev20240612.dist-info → flwr_nightly-1.10.0.dev20240619.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,131 @@
|
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Fleet API gRPC adapter servicer."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
from logging import DEBUG, INFO
|
|
19
|
+
from typing import Callable, Type, TypeVar
|
|
20
|
+
|
|
21
|
+
import grpc
|
|
22
|
+
from google.protobuf.message import Message as GrpcMessage
|
|
23
|
+
|
|
24
|
+
from flwr.common.logger import log
|
|
25
|
+
from flwr.proto import grpcadapter_pb2_grpc # pylint: disable=E0611
|
|
26
|
+
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
27
|
+
CreateNodeRequest,
|
|
28
|
+
CreateNodeResponse,
|
|
29
|
+
DeleteNodeRequest,
|
|
30
|
+
DeleteNodeResponse,
|
|
31
|
+
PingRequest,
|
|
32
|
+
PingResponse,
|
|
33
|
+
PullTaskInsRequest,
|
|
34
|
+
PullTaskInsResponse,
|
|
35
|
+
PushTaskResRequest,
|
|
36
|
+
PushTaskResResponse,
|
|
37
|
+
)
|
|
38
|
+
from flwr.proto.grpcadapter_pb2 import MessageContainer # pylint: disable=E0611
|
|
39
|
+
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
40
|
+
from flwr.server.superlink.fleet.message_handler import message_handler
|
|
41
|
+
from flwr.server.superlink.state import StateFactory
|
|
42
|
+
|
|
43
|
+
T = TypeVar("T", bound=GrpcMessage)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _handle(
    msg_container: MessageContainer,
    request_type: Type[T],
    handler: Callable[[T], GrpcMessage],
) -> MessageContainer:
    """Unwrap a request, run `handler` on it, and wrap the response.

    The bytes in `msg_container.grpc_message_content` are parsed into an
    instance of `request_type` and passed to `handler`. The message returned
    by `handler` is serialized into a new `MessageContainer` whose
    `grpc_message_name` is the qualified name of the response class and whose
    `metadata` is empty.
    """
    parsed_request = request_type.FromString(msg_container.grpc_message_content)
    response = handler(parsed_request)

    response_name = response.__class__.__qualname__
    response_bytes = response.SerializeToString()
    return MessageContainer(
        metadata={},
        grpc_message_name=response_name,
        grpc_message_content=response_bytes,
    )
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class GrpcAdapterServicer(grpcadapter_pb2_grpc.GrpcAdapterServicer):
    """Servicer exposing the Fleet API through the GrpcAdapter service.

    Every call arrives as a `MessageContainer` holding a serialized request
    message; `SendReceive` selects the handler from the container's
    `grpc_message_name` and answers with another `MessageContainer`.
    """

    def __init__(self, state_factory: StateFactory) -> None:
        # Handlers obtain their state object from this factory on every call
        self.state_factory = state_factory

    def SendReceive(
        self, request: MessageContainer, context: grpc.ServicerContext
    ) -> MessageContainer:
        """Dispatch a wrapped request to its handler and return the wrapped reply.

        Raises
        ------
        ValueError
            If `request.grpc_message_name` is not the qualified name of one of
            the supported request types.
        """
        log(DEBUG, "GrpcAdapterServicer.SendReceive")
        name = request.grpc_message_name

        if name == CreateNodeRequest.__qualname__:
            reply = _handle(request, CreateNodeRequest, self._create_node)
        elif name == DeleteNodeRequest.__qualname__:
            reply = _handle(request, DeleteNodeRequest, self._delete_node)
        elif name == PingRequest.__qualname__:
            reply = _handle(request, PingRequest, self._ping)
        elif name == PullTaskInsRequest.__qualname__:
            reply = _handle(request, PullTaskInsRequest, self._pull_task_ins)
        elif name == PushTaskResRequest.__qualname__:
            reply = _handle(request, PushTaskResRequest, self._push_task_res)
        elif name == GetRunRequest.__qualname__:
            reply = _handle(request, GetRunRequest, self._get_run)
        else:
            raise ValueError(f"Invalid grpc_message_name: {name}")
        return reply

    def _create_node(self, request: CreateNodeRequest) -> CreateNodeResponse:
        """Forward a `CreateNodeRequest` to `message_handler.create_node`."""
        log(INFO, "GrpcAdapter.CreateNode")
        state = self.state_factory.state()
        return message_handler.create_node(request=request, state=state)

    def _delete_node(self, request: DeleteNodeRequest) -> DeleteNodeResponse:
        """Forward a `DeleteNodeRequest` to `message_handler.delete_node`."""
        log(INFO, "GrpcAdapter.DeleteNode")
        state = self.state_factory.state()
        return message_handler.delete_node(request=request, state=state)

    def _ping(self, request: PingRequest) -> PingResponse:
        """Forward a `PingRequest` to `message_handler.ping`."""
        log(DEBUG, "GrpcAdapter.Ping")
        state = self.state_factory.state()
        return message_handler.ping(request=request, state=state)

    def _pull_task_ins(self, request: PullTaskInsRequest) -> PullTaskInsResponse:
        """Forward a `PullTaskInsRequest` to `message_handler.pull_task_ins`."""
        log(INFO, "GrpcAdapter.PullTaskIns")
        state = self.state_factory.state()
        return message_handler.pull_task_ins(request=request, state=state)

    def _push_task_res(self, request: PushTaskResRequest) -> PushTaskResResponse:
        """Forward a `PushTaskResRequest` to `message_handler.push_task_res`."""
        log(INFO, "GrpcAdapter.PushTaskRes")
        state = self.state_factory.state()
        return message_handler.push_task_res(request=request, state=state)

    def _get_run(self, request: GetRunRequest) -> GetRunResponse:
        """Forward a `GetRunRequest` to `message_handler.get_run`."""
        log(INFO, "GrpcAdapter.GetRun")
        state = self.state_factory.state()
        return message_handler.get_run(request=request, state=state)
|
|
@@ -29,6 +29,9 @@ from flwr.proto.transport_pb2_grpc import ( # pylint: disable=E0611
|
|
|
29
29
|
)
|
|
30
30
|
from flwr.server.client_manager import ClientManager
|
|
31
31
|
from flwr.server.superlink.driver.driver_servicer import DriverServicer
|
|
32
|
+
from flwr.server.superlink.fleet.grpc_adapter.grpc_adapter_servicer import (
|
|
33
|
+
GrpcAdapterServicer,
|
|
34
|
+
)
|
|
32
35
|
from flwr.server.superlink.fleet.grpc_bidi.flower_service_servicer import (
|
|
33
36
|
FlowerServiceServicer,
|
|
34
37
|
)
|
|
@@ -154,6 +157,7 @@ def start_grpc_server( # pylint: disable=too-many-arguments
|
|
|
154
157
|
def generic_create_grpc_server( # pylint: disable=too-many-arguments
|
|
155
158
|
servicer_and_add_fn: Union[
|
|
156
159
|
Tuple[FleetServicer, AddServicerToServerFn],
|
|
160
|
+
Tuple[GrpcAdapterServicer, AddServicerToServerFn],
|
|
157
161
|
Tuple[FlowerServiceServicer, AddServicerToServerFn],
|
|
158
162
|
Tuple[DriverServicer, AddServicerToServerFn],
|
|
159
163
|
],
|
|
@@ -112,6 +112,6 @@ def get_run(
|
|
|
112
112
|
request: GetRunRequest, state: State # pylint: disable=W0613
|
|
113
113
|
) -> GetRunResponse:
|
|
114
114
|
"""Get run information."""
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
return GetRunResponse(run=
|
|
115
|
+
run = state.get_run(request.run_id)
|
|
116
|
+
run_proto = None if run is None else Run(**vars(run))
|
|
117
|
+
return GetRunResponse(run=run_proto)
|
|
@@ -20,6 +20,7 @@ import sys
|
|
|
20
20
|
import time
|
|
21
21
|
import traceback
|
|
22
22
|
from logging import DEBUG, ERROR, INFO, WARN
|
|
23
|
+
from pathlib import Path
|
|
23
24
|
from typing import Callable, Dict, List, Optional
|
|
24
25
|
|
|
25
26
|
from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
|
|
@@ -274,6 +275,7 @@ def start_vce(
|
|
|
274
275
|
# Use mapping constructed externally. This also means nodes
|
|
275
276
|
# have previously being registered.
|
|
276
277
|
nodes_mapping = existing_nodes_mapping
|
|
278
|
+
app_dir = str(Path(app_dir).absolute())
|
|
277
279
|
|
|
278
280
|
if not state_factory:
|
|
279
281
|
log(INFO, "A StateFactory was not supplied to the SimulationEngine.")
|
|
@@ -323,7 +325,7 @@ def start_vce(
|
|
|
323
325
|
if app_dir is not None:
|
|
324
326
|
sys.path.insert(0, app_dir)
|
|
325
327
|
|
|
326
|
-
app: ClientApp = load_app(client_app_attr, LoadClientAppError)
|
|
328
|
+
app: ClientApp = load_app(client_app_attr, LoadClientAppError, app_dir)
|
|
327
329
|
|
|
328
330
|
if not isinstance(app, ClientApp):
|
|
329
331
|
raise LoadClientAppError(
|
|
@@ -23,6 +23,7 @@ from typing import Dict, List, Optional, Set, Tuple
|
|
|
23
23
|
from uuid import UUID, uuid4
|
|
24
24
|
|
|
25
25
|
from flwr.common import log, now
|
|
26
|
+
from flwr.common.typing import Run
|
|
26
27
|
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
|
27
28
|
from flwr.server.superlink.state.state import State
|
|
28
29
|
from flwr.server.utils import validate_task_ins_or_res
|
|
@@ -40,7 +41,7 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
|
40
41
|
self.public_key_to_node_id: Dict[bytes, int] = {}
|
|
41
42
|
|
|
42
43
|
# Map run_id to (fab_id, fab_version)
|
|
43
|
-
self.run_ids: Dict[int,
|
|
44
|
+
self.run_ids: Dict[int, Run] = {}
|
|
44
45
|
self.task_ins_store: Dict[UUID, TaskIns] = {}
|
|
45
46
|
self.task_res_store: Dict[UUID, TaskRes] = {}
|
|
46
47
|
|
|
@@ -281,7 +282,9 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
|
281
282
|
run_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
|
|
282
283
|
|
|
283
284
|
if run_id not in self.run_ids:
|
|
284
|
-
self.run_ids[run_id] = (
|
|
285
|
+
self.run_ids[run_id] = Run(
|
|
286
|
+
run_id=run_id, fab_id=fab_id, fab_version=fab_version
|
|
287
|
+
)
|
|
285
288
|
return run_id
|
|
286
289
|
log(ERROR, "Unexpected run creation failure.")
|
|
287
290
|
return 0
|
|
@@ -319,13 +322,13 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
|
|
|
319
322
|
"""Retrieve all currently stored `client_public_keys` as a set."""
|
|
320
323
|
return self.client_public_keys
|
|
321
324
|
|
|
322
|
-
def get_run(self, run_id: int) ->
|
|
325
|
+
def get_run(self, run_id: int) -> Optional[Run]:
|
|
323
326
|
"""Retrieve information about the run with the specified `run_id`."""
|
|
324
327
|
with self.lock:
|
|
325
328
|
if run_id not in self.run_ids:
|
|
326
329
|
log(ERROR, "`run_id` is invalid")
|
|
327
|
-
return
|
|
328
|
-
return
|
|
330
|
+
return None
|
|
331
|
+
return self.run_ids[run_id]
|
|
329
332
|
|
|
330
333
|
def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
|
|
331
334
|
"""Acknowledge a ping received from a node, serving as a heartbeat."""
|
|
@@ -24,6 +24,7 @@ from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union, cast
|
|
|
24
24
|
from uuid import UUID, uuid4
|
|
25
25
|
|
|
26
26
|
from flwr.common import log, now
|
|
27
|
+
from flwr.common.typing import Run
|
|
27
28
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
28
29
|
from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611
|
|
29
30
|
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
|
|
@@ -680,15 +681,17 @@ class SqliteState(State): # pylint: disable=R0904
|
|
|
680
681
|
result: Set[bytes] = {row["public_key"] for row in rows}
|
|
681
682
|
return result
|
|
682
683
|
|
|
683
|
-
def get_run(self, run_id: int) ->
|
|
684
|
+
def get_run(self, run_id: int) -> Optional[Run]:
|
|
684
685
|
"""Retrieve information about the run with the specified `run_id`."""
|
|
685
686
|
query = "SELECT * FROM run WHERE run_id = ?;"
|
|
686
687
|
try:
|
|
687
688
|
row = self.query(query, (run_id,))[0]
|
|
688
|
-
return
|
|
689
|
+
return Run(
|
|
690
|
+
run_id=run_id, fab_id=row["fab_id"], fab_version=row["fab_version"]
|
|
691
|
+
)
|
|
689
692
|
except sqlite3.IntegrityError:
|
|
690
693
|
log(ERROR, "`run_id` does not exist.")
|
|
691
|
-
return
|
|
694
|
+
return None
|
|
692
695
|
|
|
693
696
|
def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
|
|
694
697
|
"""Acknowledge a ping received from a node, serving as a heartbeat."""
|
|
@@ -16,9 +16,10 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import abc
|
|
19
|
-
from typing import List, Optional, Set
|
|
19
|
+
from typing import List, Optional, Set
|
|
20
20
|
from uuid import UUID
|
|
21
21
|
|
|
22
|
+
from flwr.common.typing import Run
|
|
22
23
|
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
|
23
24
|
|
|
24
25
|
|
|
@@ -160,7 +161,7 @@ class State(abc.ABC): # pylint: disable=R0904
|
|
|
160
161
|
"""Create a new run for the specified `fab_id` and `fab_version`."""
|
|
161
162
|
|
|
162
163
|
@abc.abstractmethod
|
|
163
|
-
def get_run(self, run_id: int) ->
|
|
164
|
+
def get_run(self, run_id: int) -> Optional[Run]:
|
|
164
165
|
"""Retrieve information about the run with the specified `run_id`.
|
|
165
166
|
|
|
166
167
|
Parameters
|
|
@@ -170,8 +171,8 @@ class State(abc.ABC): # pylint: disable=R0904
|
|
|
170
171
|
|
|
171
172
|
Returns
|
|
172
173
|
-------
|
|
173
|
-
|
|
174
|
-
A
|
|
174
|
+
Optional[Run]
|
|
175
|
+
A dataclass instance containing three elements if `run_id` is valid:
|
|
175
176
|
- `run_id`: The identifier of the run, same as the specified `run_id`.
|
|
176
177
|
- `fab_id`: The identifier of the FAB used in the specified run.
|
|
177
178
|
- `fab_version`: The version of the FAB used in the specified run.
|
flwr/simulation/__init__.py
CHANGED
|
@@ -53,6 +53,7 @@ def run_simulation_from_cli() -> None:
|
|
|
53
53
|
backend_name=args.backend,
|
|
54
54
|
backend_config=backend_config_dict,
|
|
55
55
|
app_dir=args.app_dir,
|
|
56
|
+
run_id=args.run_id,
|
|
56
57
|
enable_tf_gpu_growth=args.enable_tf_gpu_growth,
|
|
57
58
|
verbose_logging=args.verbose,
|
|
58
59
|
)
|
|
@@ -168,6 +169,13 @@ def run_serverapp_th(
|
|
|
168
169
|
return serverapp_th
|
|
169
170
|
|
|
170
171
|
|
|
172
|
+
def _init_run_id(driver: InMemoryDriver, state: StateFactory, run_id: int) -> None:
    """Create a run with a given `run_id`.

    Registers `run_id` directly in the state's `run_ids` mapping (with empty
    `fab_id` and `fab_version`) and assigns it to `driver.run_id`.

    Parameters
    ----------
    driver : InMemoryDriver
        Driver whose `run_id` attribute is set to `run_id`.
    state : StateFactory
        Factory providing the state object in which the run is registered.
    run_id : int
        Identifier to pre-register.
    """
    # Local import keeps this function self-contained with respect to the
    # module's import block.
    from flwr.common.typing import Run  # pylint: disable=import-outside-toplevel

    log(DEBUG, "Pre-registering run with id %s", run_id)
    # `run_ids` maps run IDs to `Run` objects and `get_run` returns the stored
    # value as-is, so a `Run` must be stored here rather than a bare tuple.
    state.state().run_ids[run_id] = Run(  # type: ignore
        run_id=run_id, fab_id="", fab_version=""
    )
    driver.run_id = run_id
|
|
177
|
+
|
|
178
|
+
|
|
171
179
|
# pylint: disable=too-many-locals
|
|
172
180
|
def _main_loop(
|
|
173
181
|
num_supernodes: int,
|
|
@@ -175,6 +183,7 @@ def _main_loop(
|
|
|
175
183
|
backend_config_stream: str,
|
|
176
184
|
app_dir: str,
|
|
177
185
|
enable_tf_gpu_growth: bool,
|
|
186
|
+
run_id: Optional[int] = None,
|
|
178
187
|
client_app: Optional[ClientApp] = None,
|
|
179
188
|
client_app_attr: Optional[str] = None,
|
|
180
189
|
server_app: Optional[ServerApp] = None,
|
|
@@ -195,6 +204,9 @@ def _main_loop(
|
|
|
195
204
|
# Initialize Driver
|
|
196
205
|
driver = InMemoryDriver(state_factory)
|
|
197
206
|
|
|
207
|
+
if run_id:
|
|
208
|
+
_init_run_id(driver, state_factory, run_id)
|
|
209
|
+
|
|
198
210
|
# Get and run ServerApp thread
|
|
199
211
|
serverapp_th = run_serverapp_th(
|
|
200
212
|
server_app_attr=server_app_attr,
|
|
@@ -244,6 +256,7 @@ def _run_simulation(
|
|
|
244
256
|
client_app_attr: Optional[str] = None,
|
|
245
257
|
server_app_attr: Optional[str] = None,
|
|
246
258
|
app_dir: str = "",
|
|
259
|
+
run_id: Optional[int] = None,
|
|
247
260
|
enable_tf_gpu_growth: bool = False,
|
|
248
261
|
verbose_logging: bool = False,
|
|
249
262
|
) -> None:
|
|
@@ -283,6 +296,9 @@ def _run_simulation(
|
|
|
283
296
|
Add specified directory to the PYTHONPATH and load `ClientApp` from there.
|
|
284
297
|
(Default: current working directory.)
|
|
285
298
|
|
|
299
|
+
run_id : Optional[int]
|
|
300
|
+
An integer specifying the ID of the run started when running this function.
|
|
301
|
+
|
|
286
302
|
enable_tf_gpu_growth : bool (default: False)
|
|
287
303
|
A boolean to indicate whether to enable GPU growth on the main thread. This is
|
|
288
304
|
desirable if you make use of a TensorFlow model on your `ServerApp` while
|
|
@@ -322,6 +338,7 @@ def _run_simulation(
|
|
|
322
338
|
backend_config_stream,
|
|
323
339
|
app_dir,
|
|
324
340
|
enable_tf_gpu_growth,
|
|
341
|
+
run_id,
|
|
325
342
|
client_app,
|
|
326
343
|
client_app_attr,
|
|
327
344
|
server_app,
|
|
@@ -413,5 +430,10 @@ def _parse_args_run_simulation() -> argparse.ArgumentParser:
|
|
|
413
430
|
"ClientApp and ServerApp from there."
|
|
414
431
|
" Default: current working directory.",
|
|
415
432
|
)
|
|
433
|
+
parser.add_argument(
|
|
434
|
+
"--run-id",
|
|
435
|
+
type=int,
|
|
436
|
+
help="Sets the ID of the run started by the Simulation Engine.",
|
|
437
|
+
)
|
|
416
438
|
|
|
417
439
|
return parser
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Flower SuperExec service."""
|
|
16
|
+
|
|
17
|
+
from .app import run_superexec as run_superexec
|
|
18
|
+
|
|
19
|
+
__all__ = [
|
|
20
|
+
"run_superexec",
|
|
21
|
+
]
|
flwr/superexec/app.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Flower SuperExec app."""
|
|
16
|
+
|
|
17
|
+
import argparse
|
|
18
|
+
import sys
|
|
19
|
+
from logging import INFO, WARN
|
|
20
|
+
from pathlib import Path
|
|
21
|
+
from typing import Optional, Tuple
|
|
22
|
+
|
|
23
|
+
import grpc
|
|
24
|
+
|
|
25
|
+
from flwr.common import EventType, event, log
|
|
26
|
+
from flwr.common.address import parse_address
|
|
27
|
+
from flwr.common.constant import SUPEREXEC_DEFAULT_ADDRESS
|
|
28
|
+
from flwr.common.exit_handlers import register_exit_handlers
|
|
29
|
+
from flwr.common.object_ref import load_app, validate
|
|
30
|
+
|
|
31
|
+
from .exec_grpc import run_superexec_api_grpc
|
|
32
|
+
from .executor import Executor
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def run_superexec() -> None:
    """Start the Flower SuperExec gRPC server and block until it terminates.

    Parses the command line, validates the listen address, obtains the
    certificates and the executor, starts the SuperExec API server, registers
    exit handlers for it, and then waits for the server to terminate.
    """
    log(INFO, "Starting Flower SuperExec")
    event(EventType.RUN_SUPEREXEC_ENTER)

    cli_args = _parse_args_run_superexec().parse_args()

    # Validate the address and rebuild it as `host:port` (or `[host]:port`
    # for IPv6)
    address_parts = parse_address(cli_args.address)
    if not address_parts:
        sys.exit(f"SuperExec IP address ({cli_args.address}) cannot be parsed.")
    host, port, is_v6 = address_parts
    if is_v6:
        address = f"[{host}]:{port}"
    else:
        address = f"{host}:{port}"

    # Certificates first, executor second
    certificates = _try_obtain_certificates(cli_args)
    executor = _load_executor(cli_args)

    # Start SuperExec API
    superexec_server: grpc.Server = run_superexec_api_grpc(
        address=address,
        executor=executor,
        certificates=certificates,
    )

    # Graceful shutdown
    register_exit_handlers(
        event_type=EventType.RUN_SUPEREXEC_LEAVE,
        grpc_servers=[superexec_server],
        bckg_threads=None,
    )

    superexec_server.wait_for_termination()
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def _parse_args_run_superexec() -> argparse.ArgumentParser:
    """Build the argument parser for the SuperExec command line.

    Nothing is parsed here; the caller invokes `parse_args()` on the returned
    parser.
    """
    parser = argparse.ArgumentParser(description="Start a Flower SuperExec")

    # Positional argument: reference to the executor object
    parser.add_argument(
        "executor",
        help="For example: `deployment:exec` or `project.package.module:wrapper.exec`.",
    )
    parser.add_argument(
        "--address",
        default=SUPEREXEC_DEFAULT_ADDRESS,
        help="SuperExec (gRPC) server address (IPv4, IPv6, or a domain name)",
    )
    parser.add_argument(
        "--executor-dir",
        default=".",
        help="The directory for the executor.",
    )

    insecure_help = (
        "Run the SuperExec without HTTPS, regardless of whether certificate "
        "paths are provided. By default, the server runs with HTTPS enabled. "
        "Use this flag only if you understand the risks."
    )
    parser.add_argument("--insecure", action="store_true", help=insecure_help)

    # The three TLS options differ only in flag name and help text
    ssl_options = (
        (
            "--ssl-certfile",
            "SuperExec server SSL certificate file (as a path str) "
            "to create a secure connection.",
        ),
        (
            "--ssl-keyfile",
            "SuperExec server SSL private key file (as a path str) "
            "to create a secure connection.",
        ),
        (
            "--ssl-ca-certfile",
            "SuperExec server SSL CA certificate file (as a path str) "
            "to create a secure connection.",
        ),
    )
    for flag, help_text in ssl_options:
        parser.add_argument(flag, type=str, default=None, help=help_text)

    return parser
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def _try_obtain_certificates(
|
|
121
|
+
args: argparse.Namespace,
|
|
122
|
+
) -> Optional[Tuple[bytes, bytes, bytes]]:
|
|
123
|
+
# Obtain certificates
|
|
124
|
+
if args.insecure:
|
|
125
|
+
log(WARN, "Option `--insecure` was set. Starting insecure HTTP server.")
|
|
126
|
+
return None
|
|
127
|
+
# Check if certificates are provided
|
|
128
|
+
if args.ssl_certfile and args.ssl_keyfile and args.ssl_ca_certfile:
|
|
129
|
+
if not Path.is_file(args.ssl_ca_certfile):
|
|
130
|
+
sys.exit("Path argument `--ssl-ca-certfile` does not point to a file.")
|
|
131
|
+
if not Path.is_file(args.ssl_certfile):
|
|
132
|
+
sys.exit("Path argument `--ssl-certfile` does not point to a file.")
|
|
133
|
+
if not Path.is_file(args.ssl_keyfile):
|
|
134
|
+
sys.exit("Path argument `--ssl-keyfile` does not point to a file.")
|
|
135
|
+
certificates = (
|
|
136
|
+
Path(args.ssl_ca_certfile).read_bytes(), # CA certificate
|
|
137
|
+
Path(args.ssl_certfile).read_bytes(), # server certificate
|
|
138
|
+
Path(args.ssl_keyfile).read_bytes(), # server private key
|
|
139
|
+
)
|
|
140
|
+
return certificates
|
|
141
|
+
if args.ssl_certfile or args.ssl_keyfile or args.ssl_ca_certfile:
|
|
142
|
+
sys.exit(
|
|
143
|
+
"You need to provide valid file paths to `--ssl-certfile`, "
|
|
144
|
+
"`--ssl-keyfile`, and `—-ssl-ca-certfile` to create a secure "
|
|
145
|
+
"connection in SuperExec server (gRPC-rere)."
|
|
146
|
+
)
|
|
147
|
+
sys.exit(
|
|
148
|
+
"Certificates are required unless running in insecure mode. "
|
|
149
|
+
"Please provide certificate paths to `--ssl-certfile`, "
|
|
150
|
+
"`--ssl-keyfile`, and `—-ssl-ca-certfile` or run the server "
|
|
151
|
+
"in insecure mode using '--insecure' if you understand the risks."
|
|
152
|
+
)
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def _load_executor(
    args: argparse.Namespace,
) -> Executor:
    """Load the executor object referenced by the `executor` argument.

    If `args.executor_dir` is not `None` it is put at the front of `sys.path`
    before loading. Raises `LoadExecutorError` when the reference fails
    validation with an error message or when the loaded object is not an
    `Executor` instance.
    """
    executor_dir = args.executor_dir
    if executor_dir is not None:
        sys.path.insert(0, executor_dir)

    executor_ref: str = args.executor
    is_valid, reason = validate(executor_ref)
    if reason and not is_valid:
        raise LoadExecutorError(reason) from None

    loaded = load_app(executor_ref, LoadExecutorError, executor_dir)
    if isinstance(loaded, Executor):
        return loaded

    raise LoadExecutorError(
        f"Attribute {executor_ref} is not of type {Executor}",
    ) from None
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
class LoadExecutorError(Exception):
    """Error when trying to load `Executor`.

    Raised by `_load_executor` when the executor reference fails validation or
    the loaded object is not an `Executor` instance; also passed to `load_app`
    as the error type to use.
    """
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""SuperExec gRPC API."""
|
|
16
|
+
|
|
17
|
+
from logging import INFO
|
|
18
|
+
from typing import Optional, Tuple
|
|
19
|
+
|
|
20
|
+
import grpc
|
|
21
|
+
|
|
22
|
+
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
|
23
|
+
from flwr.common.logger import log
|
|
24
|
+
from flwr.proto.exec_pb2_grpc import add_ExecServicer_to_server
|
|
25
|
+
from flwr.server.superlink.fleet.grpc_bidi.grpc_server import generic_create_grpc_server
|
|
26
|
+
|
|
27
|
+
from .exec_servicer import ExecServicer
|
|
28
|
+
from .executor import Executor
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def run_superexec_api_grpc(
    address: str,
    executor: Executor,
    certificates: Optional[Tuple[bytes, bytes, bytes]],
) -> grpc.Server:
    """Create and start the gRPC server for the SuperExec API.

    Parameters
    ----------
    address : str
        Address the server is created for.
    executor : Executor
        Executor handed to the `ExecServicer`.
    certificates : Optional[Tuple[bytes, bytes, bytes]]
        Forwarded unchanged to `generic_create_grpc_server`.

    Returns
    -------
    grpc.Server
        The server, already started.
    """
    # NOTE(review): the object is an `ExecServicer` but is annotated as
    # `grpc.Server` -- presumably so the call below type-checks against the
    # servicer types accepted by `generic_create_grpc_server`; confirm before
    # changing.
    servicer: grpc.Server = ExecServicer(executor=executor)
    server = generic_create_grpc_server(
        servicer_and_add_fn=(servicer, add_ExecServicer_to_server),
        server_address=address,
        max_message_length=GRPC_MAX_MESSAGE_LENGTH,
        certificates=certificates,
    )

    log(INFO, "Flower ECE: Starting SuperExec API (gRPC-rere) on %s", address)
    server.start()

    return server
|